frutiemax commited on
Commit
4f25fc2
1 Parent(s): 2104644

Fix pipeline for new latent and sample sizes

Browse files
Files changed (1) hide show
  1. rct_diffusion_pipeline.py +8 -7
rct_diffusion_pipeline.py CHANGED
@@ -12,7 +12,7 @@ import pandas as pd
12
  from tqdm.auto import tqdm
13
 
14
  class RCTDiffusionPipeline(DiffusionPipeline):
15
- def __init__(self, unet, scheduler, vae):
16
  super().__init__()
17
 
18
  # dictionnary that keeps the different classes of object description, color1, color2 and color3
@@ -24,6 +24,8 @@ class RCTDiffusionPipeline(DiffusionPipeline):
24
  self.scheduler = scheduler
25
  self.unet = unet
26
  self.vae = vae
 
 
27
 
28
  # channels for 1 image
29
  self.num_channels = int(self.unet.config.in_channels / 4)
@@ -164,15 +166,14 @@ class RCTDiffusionPipeline(DiffusionPipeline):
164
  # now put those weights into a tensor
165
  return self.pack_labels_to_tensor(batch_size, object_descriptions, colors1, colors2, colors3).to(device='cuda',dtype=torch.float16)
166
 
167
- # generate 64x64 latents
168
  def generate_noise_batches(self, batch_size):
169
- noise_batches = torch.Tensor(size=(batch_size, 4, self.num_channels, 64, 64)).to(dtype=torch.float16, device='cuda')
170
  for batch_index in range(batch_size):
171
  for view_index in range(4):
172
- noise = torch.randn(self.num_channels, 64, 64).to(dtype=torch.float16, device='cuda')
173
  noise_batches[batch_index, view_index] = noise
174
 
175
- return torch.reshape(noise_batches, (batch_size, 1, self.num_channels*4, 64, 64)).to(dtype=torch.float16, device='cuda')
176
 
177
  def __call__(self, object_description : list[list[tuple[str, float]]], color1 : list[list[tuple[str, float]]], \
178
  color2 : list[list[tuple[str, float]]] = None, color3 : list[list[tuple[str, float]]] = None, \
@@ -201,8 +202,8 @@ class RCTDiffusionPipeline(DiffusionPipeline):
201
  epoch = epoch + 1
202
 
203
  # reshape the data so we get back 4 RGB images
204
- noise_batches = torch.reshape(noise_batches, (batch_size, 4, self.num_channels, 64, 64))
205
- images = torch.Tensor(size=(batch_size, 4, 3, 512, 512))
206
 
207
  with torch.no_grad():
208
  for image_index in range(4):
 
12
  from tqdm.auto import tqdm
13
 
14
  class RCTDiffusionPipeline(DiffusionPipeline):
15
+ def __init__(self, unet, scheduler, vae, latent_size=32, sample_size=256):
16
  super().__init__()
17
 
18
  # dictionnary that keeps the different classes of object description, color1, color2 and color3
 
24
  self.scheduler = scheduler
25
  self.unet = unet
26
  self.vae = vae
27
+ self.latent_size = latent_size
28
+ self.sample_size = sample_size
29
 
30
  # channels for 1 image
31
  self.num_channels = int(self.unet.config.in_channels / 4)
 
166
  # now put those weights into a tensor
167
  return self.pack_labels_to_tensor(batch_size, object_descriptions, colors1, colors2, colors3).to(device='cuda',dtype=torch.float16)
168
 
 
169
  def generate_noise_batches(self, batch_size):
170
+ noise_batches = torch.Tensor(size=(batch_size, 4, self.num_channels, self.latent_size, self.latent_size)).to(dtype=torch.float16, device='cuda')
171
  for batch_index in range(batch_size):
172
  for view_index in range(4):
173
+ noise = torch.randn(self.num_channels, self.latent_size, self.latent_size).to(dtype=torch.float16, device='cuda')
174
  noise_batches[batch_index, view_index] = noise
175
 
176
+ return torch.reshape(noise_batches, (batch_size, 1, self.num_channels*4, self.latent_size, self.latent_size)).to(dtype=torch.float16, device='cuda')
177
 
178
  def __call__(self, object_description : list[list[tuple[str, float]]], color1 : list[list[tuple[str, float]]], \
179
  color2 : list[list[tuple[str, float]]] = None, color3 : list[list[tuple[str, float]]] = None, \
 
202
  epoch = epoch + 1
203
 
204
  # reshape the data so we get back 4 RGB images
205
+ noise_batches = torch.reshape(noise_batches, (batch_size, 4, self.num_channels, self.latent_size, self.latent_size))
206
+ images = torch.Tensor(size=(batch_size, 4, 3, self.sample_size, self.sample_size))
207
 
208
  with torch.no_grad():
209
  for image_index in range(4):