frutiemax committed
Commit 42f8b67 · 1 Parent(s): 82ebedf

Use float16 for inference and float32 for training

Files changed (3)
  1. rct_diffusion_pipeline.py +24 -27
  2. test_pipeline.py +11 -83
  3. train_model.py +60 -97
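
The change this commit makes, in one place: the modules stay in float32 while gradients are computed, and are cast to float16 only around sampling. A minimal sketch of that toggle, assuming a pipeline object that exposes unet and vae attributes (the helper name is illustrative, not part of this repo):

import torch

def sample_in_fp16(pipeline, *prompt_args):
    # cast the heavy modules to half precision for inference only
    pipeline.unet.to(dtype=torch.float16)
    pipeline.vae.to(dtype=torch.float16)
    with torch.no_grad():
        images = pipeline(*prompt_args)
    # restore float32 so the next training step accumulates gradients safely
    pipeline.unet.to(dtype=torch.float32)
    pipeline.vae.to(dtype=torch.float32)
    return images
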
rct_diffusion_pipeline.py CHANGED
@@ -30,7 +30,7 @@ class RCTDiffusionPipeline(DiffusionPipeline):
         self.text_tokenizer = text_tokenizer
 
         # channels for 1 image
-        self.num_channels = int(self.unet.config.in_channels / 4)
+        self.num_channels = int(self.unet.config.in_channels)
         self.load_dictionaries_from_dataset()
         self.register_modules(unet=unet, scheduler=scheduler, vae=vae, text_tokenizer=text_tokenizer, text_encoder=text_encoder)
 
@@ -171,13 +171,12 @@ class RCTDiffusionPipeline(DiffusionPipeline):
         return self.pack_labels_to_tensor(batch_size, object_descriptions, colors1, colors2, colors3).to(device='cuda',dtype=torch.float16)
 
     def generate_noise_batches(self, batch_size):
-        noise_batches = torch.Tensor(size=(batch_size, 4, self.num_channels, self.latent_size, self.latent_size)).to(dtype=torch.float16, device='cuda')
+        noise_batches = torch.Tensor(size=(batch_size, self.num_channels, self.latent_size, self.latent_size)).to(dtype=torch.float16, device='cuda')
         for batch_index in range(batch_size):
-            for view_index in range(4):
-                noise = torch.randn(self.num_channels, self.latent_size, self.latent_size).to(dtype=torch.float16, device='cuda')
-                noise_batches[batch_index, view_index] = noise
+            noise = torch.randn(self.num_channels, self.latent_size, self.latent_size).to(dtype=torch.float16, device='cuda')
+            noise_batches[batch_index] = noise
 
-        return torch.reshape(noise_batches, (batch_size, 1, self.num_channels*4, self.latent_size, self.latent_size)).to(dtype=torch.float16, device='cuda')
+        return torch.reshape(noise_batches, (batch_size, self.num_channels, self.latent_size, self.latent_size)).to(dtype=torch.float16, device='cuda')
 
     def test_generate_embeddings(self, object_description, color1, color2, color3) -> torch.Tensor:
         batch_size = len(object_description)
@@ -190,7 +189,7 @@ class RCTDiffusionPipeline(DiffusionPipeline):
             with torch.no_grad():
                 embeddings[batch_index] = self.text_encoder(tokens.input_ids.to('cuda'))[0]
 
-        return embeddings.to(dtype=torch.float16)
+        return embeddings
 
     def generate_embeddings(self, object_description, color1, color2, color3) -> torch.Tensor:
         batch_size = len(object_description)
@@ -244,11 +243,11 @@ class RCTDiffusionPipeline(DiffusionPipeline):
         if res == False:
            return None
        embeddings = self.test_generate_embeddings(object_description, color1, color2, color3)
-        embeddings = embeddings.to('cuda')
+        embeddings = embeddings.to(device='cuda', dtype=torch.float16)
 
        # set the inference steps
        self.scheduler.set_timesteps(num_inference_steps)
-        noise_batches = self.generate_noise_batches(batch_size)
+        noise_batches = self.generate_noise_batches(batch_size).to(dtype=torch.float16)
 
        # now call the model for the n interations
        progress_bar = tqdm(total=num_inference_steps)
@@ -257,36 +256,34 @@ class RCTDiffusionPipeline(DiffusionPipeline):
            progress_bar.set_description(f'Inference step {epoch}')
 
            for batch_index in range(batch_size):
-                noise_batches[batch_index] = self.scheduler.scale_model_input(noise_batches[batch_index], timestep=t)
+                noise_batch = self.scheduler.scale_model_input(noise_batches, timestep=t)
                with torch.no_grad():
-                    noise_residual = self.unet(noise_batches[batch_index], t, encoder_hidden_states=embeddings).sample
-                    previous_noisy_sample = self.scheduler.step(noise_residual, t, noise_batches[batch_index]).prev_sample
+                    noise_residual = self.unet(noise_batch, t, encoder_hidden_states=embeddings).sample
+                    previous_noisy_sample = self.scheduler.step(noise_residual, t, noise_batch).prev_sample
                    noise_batches[batch_index] = previous_noisy_sample
            progress_bar.update(1)
            epoch = epoch + 1
 
        # reshape the data so we get back 4 RGB images
-        noise_batches = torch.reshape(noise_batches, (batch_size, 4, self.num_channels, self.latent_size, self.latent_size))
-        images = torch.Tensor(size=(batch_size, 4, 3, self.sample_size, self.sample_size))
+        noise_batches = torch.reshape(noise_batches, (batch_size, self.num_channels, self.latent_size, self.latent_size))
+        noise_batches = noise_batches.to('cuda')
+        images = torch.Tensor(size=(batch_size, 3, self.sample_size, self.sample_size)).to('cuda')
 
        with torch.no_grad():
-            for image_index in range(4):
-                image = noise_batches[:, image_index]
-                result = self.vae.decode(image).sample
-                images[:, image_index] = result
+            image = noise_batches
+            result = self.vae.decode(image).sample
+            images = result
 
        # convert those tensors to PIL images
        output_images = []
        for batch_index in range(batch_size):
-            for image_index in range(4):
-                # run these into the vae decoder
-                image = images[batch_index, image_index]
-                image = (image / 2 + 0.5).clamp(0, 1).squeeze()
-                image = (image.permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy()
-                image = (image * 255).round().astype("uint8")
-                image = Image.fromarray(image)
-                image.save(f'test{image_index}.png')
-                output_images.append(image)
+            image = images[batch_index]
+            image = (image / 2 + 0.5).clamp(0, 1).squeeze()
+            image = (image.permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy()
+            image = (image * 255).round().astype("uint8")
+            image = Image.fromarray(image)
+            image.save(f'test{batch_index}.png')
+            output_images.append(image)
 
        # for now just return the images
        return output_images
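
A note on the image conversion in the last hunk above: the decoded sample is multiplied by 255 twice, once before the uint8 cast and once after, which wraps around in uint8. The conventional conversion for a decoded AutoencoderKL sample scales once; a sketch, with vae and latents standing in for the decoder and the (batch, 4, 32, 32) latent tensor used above:

import torch
from PIL import Image

def latents_to_pil(vae, latents):
    # decode the latents, then map [-1, 1] -> [0, 255] exactly once
    with torch.no_grad():
        decoded = vae.decode(latents).sample
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    arrays = (decoded.permute(0, 2, 3, 1) * 255).round().to(torch.uint8).cpu().numpy()
    return [Image.fromarray(array) for array in arrays]
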
test_pipeline.py CHANGED
@@ -2,6 +2,7 @@ from rct_diffusion_pipeline import RCTDiffusionPipeline
 from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer
+import torch.nn as nn
 
 torch_device = "cuda"
 
@@ -22,11 +23,17 @@ test2 = tokenizer('dark green', padding="max_length", max_length=tokenizer.model
 with torch.no_grad():
     test2 = text_encoder(test2.input_ids.to('cuda'))[0]
 
-unet = UNet2DConditionModel(sample_size=32, in_channels=16, out_channels=16, \
-    down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'),\
-    up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), cross_attention_dim=768*4,
-    block_out_channels=(64, 128, 256), norm_num_groups=32)
+unet = UNet2DConditionModel(sample_size=32, in_channels=4, out_channels=4, \
+    down_block_types=("CrossAttnDownBlock2D","CrossAttnDownBlock2D","CrossAttnDownBlock2D", "DownBlock2D"),\
+    up_block_types=("UpBlock2D","CrossAttnUpBlock2D","CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), cross_attention_dim=768,
+    block_out_channels=(320, 640, 1280, 1280), norm_num_groups=32)
 unet = unet.to('cuda', dtype=torch.float16)
+
+# put float32 for the accumulation
+for layer in unet.modules():
+    if isinstance(layer, nn.BatchNorm2d):
+        layer.float()
+
 scheduler = DDPMScheduler(num_train_timesteps=20)
 vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
 vae = vae.to('cuda', dtype=torch.float16)
@@ -34,83 +41,4 @@ vae = vae.to('cuda', dtype=torch.float16)
 pipeline = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder)
 output = pipeline(['aleppo pine tree'], ['dark green'])
 pipeline.save_pretrained('test')
-
-# from PIL import Image
-# import torch
-# from transformers import CLIPTextModel, CLIPTokenizer
-# from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
-
-# vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", use_safetensors=True)
-# tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
-# text_encoder = CLIPTextModel.from_pretrained(
-#     "CompVis/stable-diffusion-v1-4", subfolder="text_encoder", use_safetensors=True
-# )
-# unet = UNet2DConditionModel.from_pretrained(
-#     "CompVis/stable-diffusion-v1-4", subfolder="unet", use_safetensors=True
-# )
-
-# from diffusers import UniPCMultistepScheduler
-
-# scheduler = UniPCMultistepScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
-# torch_device = "cuda"
-# vae.to(torch_device)
-# text_encoder.to(torch_device)
-# unet.to(torch_device)
-
-# prompt = ["a photograph of an astronaut riding a horse"]
-# height = 512  # default height of Stable Diffusion
-# width = 512  # default width of Stable Diffusion
-# num_inference_steps = 25  # Number of denoising steps
-# guidance_scale = 7.5  # Scale for classifier-free guidance
-# generator = torch.manual_seed(0)  # Seed generator to create the inital latent noise
-# batch_size = len(prompt)
-
-# text_input = tokenizer(
-#     prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
-# )
-
-# with torch.no_grad():
-#     text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
-
-# text_input = tokenizer(
-#     prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
-# )
-
-# with torch.no_grad():
-#     text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
-
-# max_length = text_input.input_ids.shape[-1]
-# uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
-# uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
-
-# text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-
-# latents = torch.randn(
-#     (batch_size, unet.in_channels, height // 8, width // 8),
-#     generator=generator,
-# )
-# latents = latents.to(torch_device)
-
-# latents = latents * scheduler.init_noise_sigma
-
-# from tqdm.auto import tqdm
-
-# scheduler.set_timesteps(num_inference_steps)
-
-# for t in tqdm(scheduler.timesteps):
-#     # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
-#     latent_model_input = torch.cat([latents] * 2)
-
-#     latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
-
-#     # predict the noise residual
-#     with torch.no_grad():
-#         noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-#     # perform guidance
-#     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-#     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-#     # compute the previous noisy sample x_t -> x_t-1
-#     latents = scheduler.step(noise_pred, t, latents).prev_sample
 print('test')
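
One caveat about the nn.BatchNorm2d loop added above: diffusers' UNet2DConditionModel is built from GroupNorm and LayerNorm layers, so that isinstance check matches nothing in this model. A sketch of the same keep-normalization-in-float32 idea targeting the norm layers this UNet actually contains (the helper name is illustrative):

import torch.nn as nn

def keep_norm_layers_fp32(model):
    # after casting the model to float16, return its normalization
    # layers to float32 for numerically stable statistics
    for module in model.modules():
        if isinstance(module, (nn.GroupNorm, nn.LayerNorm)):
            module.float()
    return model
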
train_model.py CHANGED
@@ -12,11 +12,13 @@ from tqdm.auto import tqdm
 from accelerate import Accelerator
 from diffusers import DDPMScheduler, UNet2DConditionModel, AutoencoderKL
 from transformers import CLIPTextModel, CLIPTokenizer
+import torch.nn as nn
 
 SAMPLE_SIZE = 256
 LATENT_SIZE = 32
 SAMPLE_NUM_CHANNELS = 3
 LATENT_NUM_CHANNELS = 4
+from torchvision import transforms
 
 def save_and_test(pipeline, epoch):
     outputs = pipeline(['aleppo pine tree'], ['dark green'])
@@ -28,42 +30,22 @@ def save_and_test(pipeline, epoch):
     pipeline.save_pretrained(model_file)
 
 def convert_images(dataset):
-    # let's get all the entries for the 4 views split in four lists
-    views = []
-    num_images = int(dataset.num_rows / 4)
-
-    for view_index in range(4):
-        entries = [entry for entry in dataset if entry['view'] == view_index]
-        views.append(entries)
-
-    # convert those images to 256x256 by cropping and scaling up the image
-    image_views = []
-    for view_index in range(4):
-        images = []
-        for entry in views[view_index]:
-            image = entry['image']
-
-            scale_factor = np.minimum(LATENT_SIZE / image.width, LATENT_SIZE / image.height)
-            image = Image.resize(image, size=(int(scale_factor * image.width), int(scale_factor * image.height)), resample=Resampling.NEAREST)
-
-            new_image = PIL.Image.new('RGBA', (LATENT_SIZE, LATENT_SIZE))
-            new_image.paste(image, box=(int((LATENT_SIZE - image.width)/2), int((LATENT_SIZE - image.height)/2)))
-            images.append(new_image)
-        image_views.append(images)
-
-    del views
+    preprocess = transforms.Compose(
+        [
+            transforms.Resize((LATENT_SIZE, LATENT_SIZE)),
+            transforms.ToTensor(),
+            transforms.Normalize([0.5], [0.5]),
+        ]
+    )
 
-    # convert those views in tensors
-    targets = torch.Tensor(size=(num_images, 4, LATENT_NUM_CHANNELS, LATENT_SIZE, LATENT_SIZE)).to(dtype=torch.float16)
-    pillow_to_tensor = T.ToTensor()
+    images = [preprocess(image.convert("RGBA")) for image in dataset["image"]]
+    object_descriptions = [obj_desc for obj_desc in dataset["object_description"]]
+    colors1 = [color1 for color1 in dataset['color1']]
+    colors2 = [color1 for color1 in dataset['color2']]
+    colors3 = [color1 for color1 in dataset['color3']]
 
-    for image_index in range(num_images):
-        for view_index in range(4):
-            targets[image_index, view_index] = pillow_to_tensor(image_views[view_index][image_index]).to(dtype=torch.float16)
-    del image_views
-    del entries
-
-    return torch.reshape(targets, (num_images, 4 * LATENT_NUM_CHANNELS, LATENT_SIZE, LATENT_SIZE))
+    return {"image": images, 'object_description':object_descriptions, 'color1':colors1, \
+        'color2':colors2, 'color3':colors3}
 
 def convert_labels(dataset, model, num_images):
     # get the labels
@@ -97,115 +79,96 @@ def convert_labels(dataset, model, num_images):
     del dataset
     return class_labels.to(dtype=torch.float16, device='cuda')
 
-def train_model(batch_size=4, total_images=None, epochs=100, scheduler_num_timesteps=20, save_model_interval=10, start_learning_rate=1e-3, lr_warmup_steps=1):
-    dataset = load_dataset('frutiemax/rct_dataset')
-    dataset = dataset['train']
+def create_embeddings(dataset, model):
+    object_descriptions = dataset['object_description']
+    colors1 = dataset['color1']
+    colors2 = dataset['color2']
+    colors3 = dataset['color3']
+    return model.test_generate_embeddings(object_descriptions, colors1, colors2, colors3)
+
 
-    targets = convert_images(dataset)
+def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timesteps=20, save_model_interval=10, start_learning_rate=1e-3, lr_warmup_steps=1):
+    dataset = load_dataset('frutiemax/rct_dataset', split=f'train[0:{total_images}]')
+    dataset.set_transform(convert_images)
     num_images = int(dataset.num_rows / 4) if total_images == None else int(total_images / 4)
 
-    unet = UNet2DConditionModel(sample_size=LATENT_SIZE, in_channels=LATENT_NUM_CHANNELS*4, out_channels=LATENT_NUM_CHANNELS*4, \
+    unet = UNet2DConditionModel(sample_size=LATENT_SIZE, in_channels=LATENT_NUM_CHANNELS, out_channels=LATENT_NUM_CHANNELS, \
        down_block_types=("CrossAttnDownBlock2D","CrossAttnDownBlock2D","CrossAttnDownBlock2D", "DownBlock2D"),\
        up_block_types=("UpBlock2D","CrossAttnUpBlock2D","CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), cross_attention_dim=768,
-        block_out_channels=(320, 640, 1280, 1280), norm_num_groups=32)
-    unet = unet.to(dtype=torch.float16)
+        block_out_channels=(128, 256, 512, 512), norm_num_groups=32)
+    unet = unet.to(dtype=torch.float32)
+
+    #https://discuss.pytorch.org/t/training-with-half-precision/11815
+    for layer in unet.modules():
+        if isinstance(layer, nn.BatchNorm2d):
+            layer.float()
+
    scheduler = DDPMScheduler(num_train_timesteps=scheduler_num_timesteps)
    tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(
        "CompVis/stable-diffusion-v1-4", subfolder="text_encoder", use_safetensors=True
    ).to('cuda')
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
-    vae = vae.to(dtype=torch.float16)
+    vae = vae.to(dtype=torch.float16, device='cuda')
 
-    optimizer = torch.optim.SGD(unet.parameters(), lr=start_learning_rate)
+    optimizer = torch.optim.AdamW(unet.parameters(), lr=start_learning_rate)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps,
        num_training_steps=num_images * epochs
    )
    model = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder)
-
-    # get all the object descriptions, color1, color2, color3
-    object_descriptions = dataset['object_description']
-    colors1 = dataset['color1']
-    colors2 = dataset['color2']
-    colors3 = dataset['color3']
-
-    # we only need 1 of the 4 views
-    object_descriptions = [object_descriptions[desc_index] for desc_index in range(0, len(object_descriptions), 4)]
-    colors1 = [colors1[desc_index] for desc_index in range(0, len(colors1), 4)]
-    colors2 = [colors2[desc_index] for desc_index in range(0, len(colors2), 4)]
-    colors3 = [colors3[desc_index] for desc_index in range(0, len(colors3), 4)]
-    #embeddings = model.generate_embeddings(object_descriptions, colors1, colors2, colors3)
-    embeddings = model.test_generate_embeddings(object_descriptions, colors1, colors2, colors3)
-
-    labels = convert_labels(dataset, model, num_images)
-    del model
-
-    if total_images != None:
-        targets = targets[:int(total_images/4)]
-        label_indices = [index for index in range(0, total_images, 4)]
-        labels = labels[label_indices]
+    unet = unet.to('cuda')
+
+    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
 
    # lets train for 100 epoch for each sprite in the dataset with a random noise level
    progress_bar = tqdm(total=epochs)
-    accelerator = Accelerator(mixed_precision='fp16')
-    accelerator.clip_grad_norm_(unet.parameters(), 1.0)
-    unet, scheduler, lr_scheduler, vae = accelerator.prepare(unet, scheduler, lr_scheduler, vae)
 
    loss_fn = torch.nn.MSELoss()
 
    tensor_to_pillow = T.ToPILImage()
    for epoch in range(epochs):
        # create a noisy version of each sprite
-        for batch_index in range(0, num_images, batch_size):
-            batch_end = np.minimum(num_images, batch_index + batch_size)
-            clean_images = targets[batch_index:batch_end]
-            clean_images = torch.reshape(clean_images, ((batch_end - batch_index), LATENT_NUM_CHANNELS * 4, LATENT_SIZE, LATENT_SIZE)).\
-                to(device='cuda', dtype=torch.float16)
+        for step, batch in enumerate(train_dataloader):
+            clean_images = batch['image']
+            batch_size = clean_images.size(0)
+            embeddings = create_embeddings(batch, model)
+            clean_images = torch.reshape(clean_images, (batch['image'].size(0), LATENT_NUM_CHANNELS, LATENT_SIZE, LATENT_SIZE)).\
+                to(device='cuda')
 
-            noise = torch.randn(clean_images.shape, dtype=torch.float16, device='cuda')
-            timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_end - batch_index, )).to(device='cuda')
+            noise = torch.randn(clean_images.shape, dtype=torch.float32, device='cuda')
+            timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_size, )).to(device='cuda')
 
            #timesteps = timesteps.to(dtype=torch.int, device='cuda')
            noisy_images = scheduler.add_noise(clean_images, noise, timesteps)
-
-            # with accelerator.accumulate(unet):
-            #     assert not torch.any(torch.isnan(timesteps))
-
-            #     batch_embeddings = embeddings[batch_index:batch_end]
-            #     batch_embeddings = batch_embeddings.to('cuda')
-
-            #     optimizer.zero_grad()
-            #     unet_results = unet(noisy_images, timesteps, batch_embeddings).sample
-            #     unet_results = unet_results.to(dtype=torch.float16)
-
-            #     loss = loss_fn(unet_results, noise)
-            #     accelerator.backward(loss)
-
-            #     optimizer.step()
-            #     lr_scheduler.step()
-            #     optimizer.zero_grad()
 
-            batch_embeddings = embeddings[batch_index:batch_end]
+            batch_embeddings = embeddings
            batch_embeddings = batch_embeddings.to('cuda')
 
            optimizer.zero_grad()
            unet_results = unet(noisy_images, timesteps, batch_embeddings).sample
-            unet_results = unet_results.to(dtype=torch.float16)
            loss = loss_fn(unet_results, noise)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
 
-            progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}, last_loss={loss.item()}')
+            progress_bar.set_description(f'epoch={epoch}, batch_index={step}, last_loss={loss.item()}')
 
        if (epoch + 1) % save_model_interval == 0:
-            model = RCTDiffusionPipeline(accelerator.unwrap_model(unet), scheduler, vae, tokenizer, text_encoder)
+            # inference in float16
+            model = RCTDiffusionPipeline(unet.to(dtype=torch.float16), scheduler, \
+                vae.to(dtype=torch.float16), tokenizer, text_encoder.to(dtype=torch.float16))
            save_and_test(model, epoch)
+
+            # training in float32
+            unet.to(dtype=torch.float32)
+            vae.to(dtype=torch.float32)
+            text_encoder.to(dtype=torch.float32)
+
        progress_bar.update(1)
 
 
 if __name__ == '__main__':
-    train_model(1, total_images=4, save_model_interval=1)
+    train_model(1, total_images=4, save_model_interval=100, epochs=1000)
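
The training loop above now pulls batches through datasets.Dataset.set_transform, so images are preprocessed lazily as rows are read. In isolation, that data path looks roughly like this sketch (dataset id and column names are the ones used in train_model.py; the final print is only illustrative and assumes RGBA sprites):

import torch
from datasets import load_dataset
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),               # PIL -> CHW float in [0, 1]
    transforms.Normalize([0.5], [0.5]),  # [0, 1] -> [-1, 1]
])

def to_tensors(batch):
    # called lazily by set_transform every time rows are accessed
    batch["image"] = [preprocess(image.convert("RGBA")) for image in batch["image"]]
    return batch

dataset = load_dataset("frutiemax/rct_dataset", split="train[0:4]")
dataset.set_transform(to_tensors)
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

batch = next(iter(loader))
print(batch["image"].shape)  # expected torch.Size([1, 4, 32, 32]) for RGBA sprites
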