frutiemax committed
Commit 6fa0b52 · 1 Parent(s): 5961f34

Use SGD and text encoder/tokenizer

Files changed (3):
  1. rct_diffusion_pipeline.py +70 -7
  2. test_pipeline.py +21 -3
  3. train_model.py +82 -46
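
In short, this commit replaces the class-label conditioning with CLIP text embeddings and switches the optimizer from Adam to plain SGD. A minimal sketch of the new conditioning path, following the diffs below (the model IDs and the (1, 77, 768) output shape match the CompVis/stable-diffusion-v1-4 tokenizer and text encoder used in the diff; the prompt string is only an example):

    from transformers import CLIPTextModel, CLIPTokenizer
    import torch

    # Build the same tokenizer/text encoder pair the pipeline now registers.
    tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="text_encoder")

    # test_generate_embeddings builds one comma-separated prompt per sample.
    prompt = "aleppo pine tree,dark green,none,none"
    tokens = tokenizer(prompt, padding="max_length",
                       max_length=tokenizer.model_max_length,
                       truncation=True, return_tensors="pt")
    with torch.no_grad():
        embeddings = text_encoder(tokens.input_ids)[0]   # shape (1, 77, 768)
    # The pipeline passes these embeddings to the UNet as encoder_hidden_states.
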
rct_diffusion_pipeline.py CHANGED
@@ -12,7 +12,7 @@ import pandas as pd
 from tqdm.auto import tqdm
 
 class RCTDiffusionPipeline(DiffusionPipeline):
-    def __init__(self, unet, scheduler, vae, latent_size=32, sample_size=256):
+    def __init__(self, unet, scheduler, vae, text_tokenizer, text_encoder, latent_size=32, sample_size=256):
         super().__init__()
 
         # dictionnary that keeps the different classes of object description, color1, color2 and color3
@@ -26,11 +26,13 @@ class RCTDiffusionPipeline(DiffusionPipeline):
         self.vae = vae
         self.latent_size = latent_size
         self.sample_size = sample_size
+        self.text_encoder = text_encoder
+        self.text_tokenizer = text_tokenizer
 
         # channels for 1 image
         self.num_channels = int(self.unet.config.in_channels / 4)
         self.load_dictionaries_from_dataset()
-        self.register_modules(unet=unet, scheduler=scheduler, vae=vae)
+        self.register_modules(unet=unet, scheduler=scheduler, vae=vae, text_tokenizer=text_tokenizer, text_encoder=text_encoder)
 
     def load_dictionaries_from_dataset(self):
         dataset = load_dataset('frutiemax/rct_dataset')
@@ -177,13 +179,72 @@ class RCTDiffusionPipeline(DiffusionPipeline):
 
         return torch.reshape(noise_batches, (batch_size, 1, self.num_channels*4, self.latent_size, self.latent_size)).to(dtype=torch.float16, device='cuda')
 
-    def __call__(self, object_description : list[list[tuple[str, float]]], color1 : list[list[tuple[str, float]]], \
-        color2 : list[list[tuple[str, float]]] = None, color3 : list[list[tuple[str, float]]] = None, \
+    def test_generate_embeddings(self, object_description, color1, color2, color3) -> torch.Tensor:
+        batch_size = len(object_description)
+
+        embeddings = torch.Tensor(size=(batch_size, 77, 768))
+        for batch_index in range(batch_size):
+            prompt = f'{object_description[batch_index]},{color1[batch_index]},{color2[batch_index]}, {color3[batch_index]}'
+            tokens = self.text_tokenizer(prompt, \
+                padding="max_length", max_length=self.text_tokenizer.model_max_length, truncation=True, return_tensors="pt")
+            with torch.no_grad():
+                embeddings[batch_index] = self.text_encoder(tokens.input_ids.to('cuda'))[0]
+
+        return embeddings.to(dtype=torch.float16)
+
+    def generate_embeddings(self, object_description, color1, color2, color3) -> torch.Tensor:
+        batch_size = len(object_description)
+
+        embeddings = torch.Tensor(size=(batch_size, 77, 768 * 4))
+        for batch_index in range(batch_size):
+            object_description_tokens = self.text_tokenizer(object_description[batch_index], \
+                padding="max_length", max_length=self.text_tokenizer.model_max_length, truncation=True, return_tensors="pt")
+            color1_tokens = self.text_tokenizer(color1[batch_index], \
+                padding="max_length", max_length=self.text_tokenizer.model_max_length, truncation=True, return_tensors="pt")
+            color2_tokens = self.text_tokenizer(color2[batch_index], \
+                padding="max_length", max_length=self.text_tokenizer.model_max_length, truncation=True, return_tensors="pt")
+            color3_tokens = self.text_tokenizer(color3[batch_index], \
+                padding="max_length", max_length=self.text_tokenizer.model_max_length, truncation=True, return_tensors="pt")
+            with torch.no_grad():
+                object_description_embeddings = self.text_encoder(object_description_tokens.input_ids.to('cuda'))[0]
+                color1_embeddings = self.text_encoder(color1_tokens.input_ids.to('cuda'))[0]
+                color2_embeddings = self.text_encoder(color2_tokens.input_ids.to('cuda'))[0]
+                color3_embeddings = self.text_encoder(color3_tokens.input_ids.to('cuda'))[0]
+
+            emb = torch.cat([object_description_embeddings, color1_embeddings, color2_embeddings, color3_embeddings], dim=2)
+            embeddings[batch_index] = emb
+
+        return embeddings.to(dtype=torch.float16)
+
+    def validate_inputs(self, object_description : list[str], color1 : list[str], \
+        color2 : list[str], color3 : list[str], batch_size) -> tuple[bool, list[str], list[str], list[str], list[str]]:
+        # check if the labels sizes are correct
+        if len(object_description) != batch_size:
+            return False
+
+        if len(color1) != batch_size:
+            return False
+
+        if color2 == None:
+            color2 = ['none'] * batch_size
+        elif len(color2) != batch_size:
+            return False
+
+        if color3 == None:
+            color3 = ['none'] * batch_size
+        elif len(color3) != batch_size:
+            return False
+        return True, object_description, color1, color2, color3
+
+    def __call__(self, object_description : list[str], color1 : list[str], \
+        color2 : list[str] = None, color3 : list[str] = None, \
         batch_size=1, num_inference_steps=20, generator=torch.manual_seed(torch.random.seed())):
 
-        class_labels = self.get_class_labels(object_description, color1, color2, color3, batch_size).to(device='cuda', dtype=torch.float16)
-        if class_labels == None:
+        res, object_description, color1, color2, color3 = self.validate_inputs(object_description, color1, color2, color3, batch_size)
+        if res == False:
             return None
+        embeddings = self.test_generate_embeddings(object_description, color1, color2, color3)
+        embeddings = embeddings.to('cuda')
 
         # set the inference steps
         self.scheduler.set_timesteps(num_inference_steps)
@@ -196,8 +257,9 @@ class RCTDiffusionPipeline(DiffusionPipeline):
             progress_bar.set_description(f'Inference step {epoch}')
 
             for batch_index in range(batch_size):
+                noise_batches[batch_index] = self.scheduler.scale_model_input(noise_batches[batch_index], timestep=t)
                 with torch.no_grad():
-                    noise_residual = self.unet(noise_batches[batch_index], t, encoder_hidden_states=class_labels).sample
+                    noise_residual = self.unet(noise_batches[batch_index], t, encoder_hidden_states=embeddings).sample
                     previous_noisy_sample = self.scheduler.step(noise_residual, t, noise_batches[batch_index]).prev_sample
                     noise_batches[batch_index] = previous_noisy_sample
             progress_bar.update(1)
@@ -223,6 +285,7 @@ class RCTDiffusionPipeline(DiffusionPipeline):
             image = (image.permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy()
             image = (image * 255).round().astype("uint8")
             image = Image.fromarray(image)
+            image.save(f'test{image_index}.png')
             output_images.append(image)
 
         # for now just return the images
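
With the changes above, __call__ now takes plain lists of strings instead of lists of weighted (label, weight) tuples, and a missing color2/color3 falls back to 'none' in validate_inputs. A short usage sketch, assuming a pipeline constructed as in test_pipeline.py below:

    # object_description and color1 are required; color2/color3 are optional.
    output = pipeline(['aleppo pine tree'], ['dark green'],
                      batch_size=1, num_inference_steps=20)
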
test_pipeline.py CHANGED
@@ -1,20 +1,38 @@
 from rct_diffusion_pipeline import RCTDiffusionPipeline
 from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
 import torch
+from transformers import CLIPTextModel, CLIPTokenizer
 
 torch_device = "cuda"
 
+# test of text tokenizers
+tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
+text_encoder = CLIPTextModel.from_pretrained(
+    "CompVis/stable-diffusion-v1-4", subfolder="text_encoder", use_safetensors=True
+).to('cuda')
+
+test1 = tokenizer(['aleppo pine tree, common oak tree'], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+#test3 = tokenizer([1.0, 0.0, .05], is_split_into_words=True, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+
+with torch.no_grad():
+    test1 = text_encoder(test1.input_ids.to('cuda'))[0]
+
+test2 = tokenizer('dark green', padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+
+with torch.no_grad():
+    test2 = text_encoder(test2.input_ids.to('cuda'))[0]
+
 unet = UNet2DConditionModel(sample_size=32, in_channels=16, out_channels=16, \
     down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'),\
-    up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), cross_attention_dim=160,
+    up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), cross_attention_dim=768*4,
     block_out_channels=(64, 128, 256), norm_num_groups=32)
 unet = unet.to('cuda', dtype=torch.float16)
 scheduler = DDPMScheduler(num_train_timesteps=20)
 vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
 vae = vae.to('cuda', dtype=torch.float16)
 
-pipeline = RCTDiffusionPipeline(unet, scheduler, vae)
-output = pipeline([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]])
+pipeline = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder)
+output = pipeline(['aleppo pine tree'], ['dark green'])
 pipeline.save_pretrained('test')
 
 # from PIL import Image
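
The cross_attention_dim=768*4 used in this test matches the width produced by generate_embeddings, which concatenates the four per-field CLIP embeddings (object description plus three colors) along the hidden axis; test_generate_embeddings, which __call__ currently uses, keeps a single 768-wide embedding instead. A small shape illustration of the concatenated variant:

    import torch

    # Four (1, 77, 768) CLIP hidden states, one per field, concatenated on the last axis.
    fields = [torch.zeros(1, 77, 768) for _ in range(4)]
    emb = torch.cat(fields, dim=2)
    print(emb.shape)   # torch.Size([1, 77, 3072]), i.e. 768 * 4
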
train_model.py CHANGED
@@ -11,6 +11,7 @@ from diffusers.optimization import get_cosine_schedule_with_warmup
 from tqdm.auto import tqdm
 from accelerate import Accelerator
 from diffusers import DDPMScheduler, UNet2DConditionModel, AutoencoderKL
+from transformers import CLIPTextModel, CLIPTokenizer
 
 SAMPLE_SIZE = 256
 LATENT_SIZE = 32
@@ -18,12 +19,12 @@ SAMPLE_NUM_CHANNELS = 3
 LATENT_NUM_CHANNELS = 4
 
 def save_and_test(pipeline, epoch):
-    outputs = pipeline([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]])
+    outputs = pipeline(['aleppo pine tree'], ['dark green'])
     for image_index in range(len(outputs)):
         file_name = f'out{image_index}_{epoch}.png'
         outputs[image_index].save(file_name)
 
-    model_file = f'rct_foliage_{epoch}.pth'
+    model_file = f'rct_foliage_{epoch}'
    pipeline.save_pretrained(model_file)
 
 def convert_images(dataset):
@@ -42,18 +43,18 @@ def convert_images(dataset):
         for entry in views[view_index]:
             image = entry['image']
 
-            scale_factor = int(np.minimum(SAMPLE_SIZE / image.width, SAMPLE_SIZE / image.height))
-            image = Image.resize(image, size=(scale_factor * image.width, scale_factor * image.height), resample=Resampling.NEAREST)
+            scale_factor = np.minimum(LATENT_SIZE / image.width, LATENT_SIZE / image.height)
+            image = Image.resize(image, size=(int(scale_factor * image.width), int(scale_factor * image.height)), resample=Resampling.NEAREST)
 
-            new_image = PIL.Image.new('RGB', (SAMPLE_SIZE, SAMPLE_SIZE))
-            new_image.paste(image, box=(int((SAMPLE_SIZE - image.width)/2), int((SAMPLE_SIZE - image.height)/2)))
+            new_image = PIL.Image.new('RGBA', (LATENT_SIZE, LATENT_SIZE))
+            new_image.paste(image, box=(int((LATENT_SIZE - image.width)/2), int((LATENT_SIZE - image.height)/2)))
             images.append(new_image)
         image_views.append(images)
 
     del views
 
     # convert those views in tensors
-    targets = torch.Tensor(size=(num_images, 4, SAMPLE_NUM_CHANNELS, SAMPLE_SIZE, SAMPLE_SIZE)).to(dtype=torch.float16)
+    targets = torch.Tensor(size=(num_images, 4, LATENT_NUM_CHANNELS, LATENT_SIZE, LATENT_SIZE)).to(dtype=torch.float16)
     pillow_to_tensor = T.ToTensor()
 
     for image_index in range(num_images):
@@ -62,7 +63,7 @@ def convert_images(dataset):
     del image_views
     del entries
 
-    return torch.reshape(targets, (num_images, 4 * SAMPLE_NUM_CHANNELS, SAMPLE_SIZE, SAMPLE_SIZE))
+    return torch.reshape(targets, (num_images, 4 * LATENT_NUM_CHANNELS, LATENT_SIZE, LATENT_SIZE))
 
 def convert_labels(dataset, model, num_images):
     # get the labels
@@ -96,80 +97,115 @@ def convert_labels(dataset, model, num_images):
     del dataset
     return class_labels.to(dtype=torch.float16, device='cuda')
 
-def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model_interval=10, start_learning_rate=1e-3, lr_warmup_steps=500):
+def train_model(batch_size=4, total_images=None, epochs=100, scheduler_num_timesteps=20, save_model_interval=10, start_learning_rate=1e-3, lr_warmup_steps=1):
     dataset = load_dataset('frutiemax/rct_dataset')
     dataset = dataset['train']
 
     targets = convert_images(dataset)
-    num_images = int(dataset.num_rows / 4)
+    num_images = int(dataset.num_rows / 4) if total_images == None else int(total_images / 4)
 
-    unet = UNet2DConditionModel(sample_size=LATENT_SIZE, in_channels=LATENT_NUM_CHANNELS * 4, out_channels=LATENT_NUM_CHANNELS * 4, \
-        down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'),\
-        up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), cross_attention_dim=160,
-        block_out_channels=(64, 128, 256), norm_num_groups=32)
+    unet = UNet2DConditionModel(sample_size=LATENT_SIZE, in_channels=LATENT_NUM_CHANNELS*4, out_channels=LATENT_NUM_CHANNELS*4, \
+        down_block_types=("CrossAttnDownBlock2D","CrossAttnDownBlock2D","CrossAttnDownBlock2D", "DownBlock2D"),\
+        up_block_types=("UpBlock2D","CrossAttnUpBlock2D","CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), cross_attention_dim=768,
+        block_out_channels=(320, 640, 1280, 1280), norm_num_groups=32)
     unet = unet.to(dtype=torch.float16)
-    scheduler = DDPMScheduler(num_train_timesteps=20)
+    scheduler = DDPMScheduler(num_train_timesteps=scheduler_num_timesteps)
+    tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
+    text_encoder = CLIPTextModel.from_pretrained(
+        "CompVis/stable-diffusion-v1-4", subfolder="text_encoder", use_safetensors=True
+    ).to('cuda')
     vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
     vae = vae.to(dtype=torch.float16)
 
-    optimizer = torch.optim.Adam(unet.parameters(), lr=start_learning_rate)
+    optimizer = torch.optim.SGD(unet.parameters(), lr=start_learning_rate)
     lr_scheduler = get_cosine_schedule_with_warmup(
         optimizer=optimizer,
         num_warmup_steps=lr_warmup_steps,
         num_training_steps=num_images * epochs
     )
-    model = RCTDiffusionPipeline(unet, scheduler, vae)
+    model = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder)
+
+    # get all the object descriptions, color1, color2, color3
+    object_descriptions = dataset['object_description']
+    colors1 = dataset['color1']
+    colors2 = dataset['color2']
+    colors3 = dataset['color3']
+
+    # we only need 1 of the 4 views
+    object_descriptions = [object_descriptions[desc_index] for desc_index in range(0, len(object_descriptions), 4)]
+    colors1 = [colors1[desc_index] for desc_index in range(0, len(colors1), 4)]
+    colors2 = [colors2[desc_index] for desc_index in range(0, len(colors2), 4)]
+    colors3 = [colors3[desc_index] for desc_index in range(0, len(colors3), 4)]
+    #embeddings = model.generate_embeddings(object_descriptions, colors1, colors2, colors3)
+    embeddings = model.test_generate_embeddings(object_descriptions, colors1, colors2, colors3)
+
     labels = convert_labels(dataset, model, num_images)
     del model
 
+    if total_images != None:
+        targets = targets[:int(total_images/4)]
+        label_indices = [index for index in range(0, total_images, 4)]
+        labels = labels[label_indices]
+
     # lets train for 100 epoch for each sprite in the dataset with a random noise level
     progress_bar = tqdm(total=epochs)
     accelerator = Accelerator(mixed_precision='fp16')
+    accelerator.clip_grad_norm_(unet.parameters(), 1.0)
    unet, scheduler, lr_scheduler, vae = accelerator.prepare(unet, scheduler, lr_scheduler, vae)
 
+    loss_fn = torch.nn.MSELoss()
+
+    tensor_to_pillow = T.ToPILImage()
     for epoch in range(epochs):
         # create a noisy version of each sprite
         for batch_index in range(0, num_images, batch_size):
-            progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}')
             batch_end = np.minimum(num_images, batch_index + batch_size)
             clean_images = targets[batch_index:batch_end]
-            clean_images = torch.reshape(clean_images, ((batch_end - batch_index), SAMPLE_NUM_CHANNELS * 4, SAMPLE_SIZE, SAMPLE_SIZE)).to(device='cuda', dtype=torch.float16)
+            clean_images = torch.reshape(clean_images, ((batch_end - batch_index), LATENT_NUM_CHANNELS * 4, LATENT_SIZE, LATENT_SIZE)).\
+                to(device='cuda', dtype=torch.float16)
 
             noise = torch.randn(clean_images.shape, dtype=torch.float16, device='cuda')
             timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_end - batch_index, )).to(device='cuda')
+
             #timesteps = timesteps.to(dtype=torch.int, device='cuda')
             noisy_images = scheduler.add_noise(clean_images, noise, timesteps)
-            del clean_images
-
-            # encode through the vae
-            with accelerator.accumulate(unet):
-                latent_images = torch.Tensor(size=(batch_end - batch_index, LATENT_NUM_CHANNELS * 4, LATENT_SIZE, LATENT_SIZE)).to(device='cuda', dtype=torch.float16)
-                latent_noises = torch.Tensor(size=(batch_end - batch_index, LATENT_NUM_CHANNELS * 4, LATENT_SIZE, LATENT_SIZE)).to(device='cuda', dtype=torch.float16)
-                for view_index in range(4):
-                    images = noisy_images[:, view_index*SAMPLE_NUM_CHANNELS:(view_index+1)*SAMPLE_NUM_CHANNELS]
-                    result = vae.encode(images).latent_dist.sample()
-                    latent_images[:, view_index*LATENT_NUM_CHANNELS:(view_index+1)*LATENT_NUM_CHANNELS] = result
-
-                    images = noise[:, view_index*SAMPLE_NUM_CHANNELS:(view_index+1)*SAMPLE_NUM_CHANNELS]
-                    result = vae.encode(images).latent_dist.sample()
-                    latent_noises[:, view_index*LATENT_NUM_CHANNELS:(view_index+1)*LATENT_NUM_CHANNELS] = result
-
-                del noise
-                del noisy_images
-                unet_results = unet(latent_images, timesteps, labels[batch_index:batch_end])[0]
-                unet_results = unet_results.to(dtype=torch.float16)
-
-                loss = F.mse_loss(unet_results, latent_noises)
-                accelerator.backward(loss)
-                optimizer.step()
-                lr_scheduler.step()
-                optimizer.zero_grad()
+
+            # with accelerator.accumulate(unet):
+            # assert not torch.any(torch.isnan(timesteps))
+
+            # batch_embeddings = embeddings[batch_index:batch_end]
+            # batch_embeddings = batch_embeddings.to('cuda')
+
+            # optimizer.zero_grad()
+            # unet_results = unet(noisy_images, timesteps, batch_embeddings).sample
+            # unet_results = unet_results.to(dtype=torch.float16)
+
+            # loss = loss_fn(unet_results, noise)
+            # accelerator.backward(loss)
+
+            # optimizer.step()
+            # lr_scheduler.step()
+            # optimizer.zero_grad()
+
+            batch_embeddings = embeddings[batch_index:batch_end]
+            batch_embeddings = batch_embeddings.to('cuda')
+
+            optimizer.zero_grad()
+            unet_results = unet(noisy_images, timesteps, batch_embeddings).sample
+            unet_results = unet_results.to(dtype=torch.float16)
+            loss = loss_fn(unet_results, noise)
+            loss.backward()
+            optimizer.step()
+            lr_scheduler.step()
+            optimizer.zero_grad()
+
+            progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}, last_loss={loss.item()}')
 
         if (epoch + 1) % save_model_interval == 0:
-            model = RCTDiffusionPipeline(accelerator.unwrap_model(unet), scheduler, vae)
+            model = RCTDiffusionPipeline(accelerator.unwrap_model(unet), scheduler, vae, tokenizer, text_encoder)
             save_and_test(model, epoch)
         progress_bar.update(1)
 
 
 if __name__ == '__main__':
-    train_model(1, save_model_interval=1)
+    train_model(1, total_images=4, save_model_interval=1)
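
The reworked loop trains the UNet to predict the added noise directly on the 16-channel latent-sized targets (the per-batch VAE encoding step was removed), using SGD with a cosine warmup schedule. A condensed, self-contained sketch of that optimizer/scheduler pairing; the Linear module and the step counts are placeholders for the UNet and dataset size:

    import torch
    from diffusers.optimization import get_cosine_schedule_with_warmup

    model = torch.nn.Linear(4, 4)            # stands in for the UNet
    num_images, epochs = 4, 100              # stands in for the dataset size

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=1,                  # the new lr_warmup_steps default
        num_training_steps=num_images * epochs,
    )

    loss_fn = torch.nn.MSELoss()
    for _ in range(num_images * epochs):
        prediction = model(torch.randn(1, 4))
        loss = loss_fn(prediction, torch.zeros(1, 4))   # placeholder noise-prediction loss
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
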