Fix pipeline for new latent and sample sizes
rct_diffusion_pipeline.py (changed)
```diff
@@ -12,7 +12,7 @@ import pandas as pd
 from tqdm.auto import tqdm
 
 class RCTDiffusionPipeline(DiffusionPipeline):
-    def __init__(self, unet, scheduler, vae):
+    def __init__(self, unet, scheduler, vae, latent_size=32, sample_size=256):
         super().__init__()
 
         # dictionnary that keeps the different classes of object description, color1, color2 and color3
@@ -24,6 +24,8 @@ class RCTDiffusionPipeline(DiffusionPipeline):
         self.scheduler = scheduler
         self.unet = unet
         self.vae = vae
+        self.latent_size = latent_size
+        self.sample_size = sample_size
 
         # channels for 1 image
         self.num_channels = int(self.unet.config.in_channels / 4)
@@ -164,15 +166,14 @@ class RCTDiffusionPipeline(DiffusionPipeline):
         # now put those weights into a tensor
         return self.pack_labels_to_tensor(batch_size, object_descriptions, colors1, colors2, colors3).to(device='cuda',dtype=torch.float16)
 
-    # generate 64x64 latents
     def generate_noise_batches(self, batch_size):
-        noise_batches = torch.Tensor(size=(batch_size, 4, self.num_channels, 64, 64)).to(dtype=torch.float16, device='cuda')
+        noise_batches = torch.Tensor(size=(batch_size, 4, self.num_channels, self.latent_size, self.latent_size)).to(dtype=torch.float16, device='cuda')
         for batch_index in range(batch_size):
             for view_index in range(4):
-                noise = torch.randn(self.num_channels, 64, 64).to(dtype=torch.float16, device='cuda')
+                noise = torch.randn(self.num_channels, self.latent_size, self.latent_size).to(dtype=torch.float16, device='cuda')
                 noise_batches[batch_index, view_index] = noise
 
-        return torch.reshape(noise_batches, (batch_size, 1, self.num_channels*4, 64, 64)).to(dtype=torch.float16, device='cuda')
+        return torch.reshape(noise_batches, (batch_size, 1, self.num_channels*4, self.latent_size, self.latent_size)).to(dtype=torch.float16, device='cuda')
 
     def __call__(self, object_description : list[list[tuple[str, float]]], color1 : list[list[tuple[str, float]]], \
         color2 : list[list[tuple[str, float]]] = None, color3 : list[list[tuple[str, float]]] = None, \
@@ -201,8 +202,8 @@ class RCTDiffusionPipeline(DiffusionPipeline):
             epoch = epoch + 1
 
         # reshape the data so we get back 4 RGB images
-        noise_batches = torch.reshape(noise_batches, (batch_size, 4, self.num_channels, 64, 64))
-        images = torch.Tensor(size=(batch_size, 4, 3,
+        noise_batches = torch.reshape(noise_batches, (batch_size, 4, self.num_channels, self.latent_size, self.latent_size))
+        images = torch.Tensor(size=(batch_size, 4, 3, self.sample_size, self.sample_size))
 
         with torch.no_grad():
             for image_index in range(4):
```