frutiemax committed on
Commit 9bde8da
1 Parent(s): 42f8b67

Use the VAE for encoding and decoding during training

Files changed (3):
  1. rct_diffusion_pipeline.py +21 -10
  2. test_pipeline.py +4 -2
  3. train_model.py +29 -16
rct_diffusion_pipeline.py CHANGED
@@ -172,8 +172,12 @@ class RCTDiffusionPipeline(DiffusionPipeline):
 
     def generate_noise_batches(self, batch_size):
         noise_batches = torch.Tensor(size=(batch_size, self.num_channels, self.latent_size, self.latent_size)).to(dtype=torch.float16, device='cuda')
+        seed = int(0)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
         for batch_index in range(batch_size):
-            noise = torch.randn(self.num_channels, self.latent_size, self.latent_size).to(dtype=torch.float16, device='cuda')
+            noise = torch.randn((self.num_channels, self.latent_size, self.latent_size)).to(dtype=torch.float16, device='cuda')
             noise_batches[batch_index] = noise
 
         return torch.reshape(noise_batches, (batch_size, self.num_channels, self.latent_size, self.latent_size)).to(dtype=torch.float16, device='cuda')
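Note what this hunk does: it reseeds NumPy and both torch RNGs with a fixed seed of 0 on every call, so the pipeline always starts from identical noise and the `generator` argument of `__call__` goes unused. If determinism should be opt-in instead, a minimal sketch (assuming the caller passes a CPU `torch.Generator`; not part of this commit) is:

    def generate_noise_batches(self, batch_size, generator=None):
        # Sample the whole batch in one randn call; the optional generator
        # controls reproducibility without touching global RNG state.
        noise = torch.randn(
            (batch_size, self.num_channels, self.latent_size, self.latent_size),
            generator=generator)
        return noise.to(dtype=torch.float16, device='cuda')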
@@ -237,7 +241,11 @@ class RCTDiffusionPipeline(DiffusionPipeline):
 
     def __call__(self, object_description : list[str], color1 : list[str], \
                  color2 : list[str] = None, color3 : list[str] = None, \
-                 batch_size=1, num_inference_steps=20, generator=torch.manual_seed(torch.random.seed())):
+                 batch_size=1, num_inference_steps=100, generator=torch.manual_seed(torch.random.seed())):
+
+        self.unet.to(device='cuda', dtype=torch.float16)
+        self.vae.to(device='cuda', dtype=torch.float16)
+        self.text_encoder.to(device='cuda', dtype=torch.float16)
 
         res, object_description, color1, color2, color3 = self.validate_inputs(object_description, color1, color2, color3, batch_size)
         if res == False:
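A caveat on the signature above: the default `generator=torch.manual_seed(torch.random.seed())` is evaluated once, at function-definition time, so every call shares that one generator object. A common pattern (a sketch, not the commit's code) is to default to `None` and build the generator inside the call:

    def __call__(self, object_description, color1, color2=None, color3=None,
                 batch_size=1, num_inference_steps=100, generator=None):
        if generator is None:
            # Fresh generator per call unless the caller supplies one.
            generator = torch.Generator().manual_seed(torch.seed())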
@@ -268,20 +276,23 @@ class RCTDiffusionPipeline(DiffusionPipeline):
         noise_batches = torch.reshape(noise_batches, (batch_size, self.num_channels, self.latent_size, self.latent_size))
         noise_batches = noise_batches.to('cuda')
         images = torch.Tensor(size=(batch_size, 3, self.sample_size, self.sample_size)).to('cuda')
+        images = noise_batches[:, :3]
 
-        with torch.no_grad():
-            image = noise_batches
-            result = self.vae.decode(image).sample
-            images = result
+        #with torch.no_grad():
+        #    image = noise_batches
+        #    result = self.vae.decode(image).sample
+        #    images = result
 
         # convert those tensors to PIL images
+        tensor_to_pil = T.ToPILImage()
         output_images = []
         for batch_index in range(batch_size):
             image = images[batch_index]
-            image = (image / 2 + 0.5).clamp(0, 1).squeeze()
-            image = (image.permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy()
-            image = (image * 255).round().astype("uint8")
-            image = Image.fromarray(image)
+            image = (image / 2 + 0.5).clamp(0, 1)
+            #image = (image.permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy()
+            #image = (image * 255).round().astype("uint8")
+            #image = Image.fromarray(image)
+            image = tensor_to_pil(image)
             image.save(f'test{batch_index}.png')
             output_images.append(image)
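With the `vae.decode` block commented out, `images = noise_batches[:, :3]` previews the first three latent channels directly, at `latent_size` resolution rather than the `sample_size` the preallocated tensor advertises. If the decode is re-enabled, latents conventionally need un-scaling first; a hedged sketch assuming the Stable Diffusion scaling convention (`scaling_factor` of 0.18215 for sd-vae-ft-mse):

    with torch.no_grad():
        latents = noise_batches / self.vae.config.scaling_factor  # assumed SD convention
        images = self.vae.decode(latents).sample  # (B, 3, sample_size, sample_size) in [-1, 1]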
 
 
test_pipeline.py CHANGED
@@ -38,7 +38,9 @@ scheduler = DDPMScheduler(num_train_timesteps=20)
 vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
 vae = vae.to('cuda', dtype=torch.float16)
 
-pipeline = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder)
-output = pipeline(['aleppo pine tree'], ['dark green'])
+#pipeline = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder)
+pipeline = RCTDiffusionPipeline.from_pretrained('rct_foliage_999')
+output = pipeline(['pagoda pine tree'], ['green'], ['grey'])
+output[0].save('out.png')
 pipeline.save_pretrained('test')
 print('test')
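The test now round-trips a trained checkpoint instead of constructing a fresh pipeline: `rct_foliage_999` matches the `f'rct_foliage_{epoch}'` naming used by `save_and_test` in train_model.py, i.e. the save from epoch 999 of a 1000-epoch run.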
train_model.py CHANGED
@@ -6,6 +6,7 @@ import numpy as np
 from rct_diffusion_pipeline import RCTDiffusionPipeline
 import torch
 import torchvision.transforms as T
+import torchvision
 import torch.nn.functional as F
 from diffusers.optimization import get_cosine_schedule_with_warmup
 from tqdm.auto import tqdm
@@ -29,16 +30,21 @@ def save_and_test(pipeline, epoch):
     model_file = f'rct_foliage_{epoch}'
     pipeline.save_pretrained(model_file)
 
-def convert_images(dataset):
-    preprocess = transforms.Compose(
-        [
-            transforms.Resize((LATENT_SIZE, LATENT_SIZE)),
-            transforms.ToTensor(),
-            transforms.Normalize([0.5], [0.5]),
-        ]
-    )
+def transform_images(image):
+    res = torch.Tensor(SAMPLE_NUM_CHANNELS, SAMPLE_SIZE, SAMPLE_SIZE)
+    pil_to_tensor = T.PILToTensor()
+
+    res_index = 0
+    scale_factor = np.minimum(SAMPLE_SIZE / image.width, SAMPLE_SIZE / image.height)
+    image = Image.resize(image, size=(int(scale_factor * image.width), int(scale_factor * image.height)), resample=Resampling.NEAREST)
 
-    images = [preprocess(image.convert("RGBA")) for image in dataset["image"]]
+    new_image = PIL.Image.new('RGB', (SAMPLE_SIZE, SAMPLE_SIZE))
+    new_image.paste(image, box=(int((SAMPLE_SIZE - image.width)/2), int((SAMPLE_SIZE - image.height)/2)))
+    res = pil_to_tensor(new_image)
+    return res
+
+def convert_images(dataset):
+    images = [transform_images(image) for image in dataset["image"]]
     object_descriptions = [obj_desc for obj_desc in dataset["object_description"]]
     colors1 = [color1 for color1 in dataset['color1']]
     colors2 = [color1 for color1 in dataset['color2']]
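As committed, `transform_images` has a few rough edges: `Image.resize` is invoked as a module-level function (Pillow exposes `resize` as a method on image instances), `Resampling` is referenced without a visible import, and `res_index` plus the new `torchvision` import are unused. A hedged sketch of the same letterbox transform, assuming `from PIL import Image` and Pillow >= 9.1:

    def transform_images(image):
        # Fit the sprite inside a SAMPLE_SIZE square while preserving aspect ratio.
        scale = min(SAMPLE_SIZE / image.width, SAMPLE_SIZE / image.height)
        image = image.resize((int(scale * image.width), int(scale * image.height)),
                             resample=Image.Resampling.NEAREST)

        # Center it on a black canvas and convert to a CHW uint8 tensor.
        canvas = Image.new('RGB', (SAMPLE_SIZE, SAMPLE_SIZE))
        canvas.paste(image, box=((SAMPLE_SIZE - image.width) // 2,
                                 (SAMPLE_SIZE - image.height) // 2))
        return T.PILToTensor()(canvas)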
@@ -87,10 +93,10 @@ def create_embeddings(dataset, model):
     return model.test_generate_embeddings(object_descriptions, colors1, colors2, colors3)
 
 
-def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timesteps=20, save_model_interval=10, start_learning_rate=1e-3, lr_warmup_steps=1):
+def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timesteps=100, save_model_interval=10, start_learning_rate=1e-4, lr_warmup_steps=500):
     dataset = load_dataset('frutiemax/rct_dataset', split=f'train[0:{total_images}]')
     dataset.set_transform(convert_images)
-    num_images = int(dataset.num_rows / 4) if total_images == None else int(total_images / 4)
+    num_images = dataset.num_rows
 
     unet = UNet2DConditionModel(sample_size=LATENT_SIZE, in_channels=LATENT_NUM_CHANNELS, out_channels=LATENT_NUM_CHANNELS, \
         down_block_types=("CrossAttnDownBlock2D","CrossAttnDownBlock2D","CrossAttnDownBlock2D", "DownBlock2D"),\
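One side effect of the defaults above: with `total_images=-1`, the split string becomes `train[0:-1]`, which in the datasets slice syntax selects everything except the last example. An explicit guard (hypothetical, not in the commit) avoids the off-by-one:

    split = 'train' if total_images in (None, -1) else f'train[:{total_images}]'
    dataset = load_dataset('frutiemax/rct_dataset', split=split)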
@@ -109,7 +115,7 @@ def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timeste
         "CompVis/stable-diffusion-v1-4", subfolder="text_encoder", use_safetensors=True
     ).to('cuda')
     vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
-    vae = vae.to(dtype=torch.float16, device='cuda')
+    vae = vae.to(dtype=torch.float32, device='cuda')
 
     optimizer = torch.optim.AdamW(unet.parameters(), lr=start_learning_rate)
     lr_scheduler = get_cosine_schedule_with_warmup(
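Keeping the VAE in float32 for training sidesteps the overflow issues fp16 autoencoding is prone to; `__call__` still casts it to fp16 at inference time. Since only `unet.parameters()` are handed to the optimizer the VAE is never updated, but it still tracks gradients inside the training graph; a common extra step (assumed, not in this commit) is to freeze it explicitly:

    vae.requires_grad_(False)  # no grads accumulated in VAE weights
    vae.eval()                 # gradients still flow through to the UNet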
@@ -134,7 +140,7 @@ def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timeste
             clean_images = batch['image']
             batch_size = clean_images.size(0)
             embeddings = create_embeddings(batch, model)
-            clean_images = torch.reshape(clean_images, (batch['image'].size(0), LATENT_NUM_CHANNELS, LATENT_SIZE, LATENT_SIZE)).\
+            clean_images = torch.reshape(clean_images, (batch['image'].size(0), SAMPLE_NUM_CHANNELS, SAMPLE_SIZE, SAMPLE_SIZE)).\
                 to(device='cuda')
 
             noise = torch.randn(clean_images.shape, dtype=torch.float32, device='cuda')
@@ -146,9 +152,16 @@ def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timeste
             batch_embeddings = embeddings
             batch_embeddings = batch_embeddings.to('cuda')
 
+            # use the vae to get the latent images
+            latent_images = vae.encode(noisy_images).latent_dist.sample()
+
             optimizer.zero_grad()
-            unet_results = unet(noisy_images, timesteps, batch_embeddings).sample
-            loss = loss_fn(unet_results, noise)
+            unet_results = unet(latent_images, timesteps, batch_embeddings).sample
+
+            # get back the upscale result
+            noise_pred = vae.decode(unet_results).sample
+
+            loss = loss_fn(noise_pred, noise)
             loss.backward()
             optimizer.step()
             lr_scheduler.step()
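For context, this hunk keeps the loss in pixel space: the already-noised images are encoded, the UNet runs on those latents, the prediction is decoded, and the result is compared against pixel-space noise, with both VAE passes inside the backward graph. The conventional latent-diffusion recipe instead encodes the clean images once, adds noise in latent space, and computes the loss on latent noise with no decode in the loop. A hedged sketch of that recipe, reusing this file's names where they exist (`noise_scheduler` and the scaling factor are assumptions):

    with torch.no_grad():
        # Encode clean images once; the VAE stays frozen.
        latents = vae.encode(clean_images).latent_dist.sample()
        latents = latents * vae.config.scaling_factor  # SD convention, assumed

    noise = torch.randn_like(latents)
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    optimizer.zero_grad()
    noise_pred = unet(noisy_latents, timesteps, batch_embeddings).sample
    loss = loss_fn(noise_pred, noise)  # latent-space loss; no decode needed
    loss.backward()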
@@ -171,4 +184,4 @@ def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timeste
 
 
 if __name__ == '__main__':
-    train_model(1, total_images=4, save_model_interval=100, epochs=1000)
+    train_model(1, save_model_interval=10, epochs=100)
 