Don't decode in the training phase
rct_diffusion_pipeline.py  +5 -4
train_model.py  +11 -10
rct_diffusion_pipeline.py

@@ -278,10 +278,11 @@ class RCTDiffusionPipeline(DiffusionPipeline):
     images = torch.Tensor(size=(batch_size, 3, self.sample_size, self.sample_size)).to('cuda')
     images = noise_batches[:, :3]
 
-
-
-
-
+    with torch.no_grad():
+        image = noise_batches
+        result = self.vae.decode(image).sample
+        images = result
+        images = images / self.vae.config.scaling_factor
 
     # convert those tensors to PIL images
     tensor_to_pil = T.ToPILImage()
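The decode now happens only at inference time, inside `torch.no_grad()`, so no autograd graph is built through the VAE decoder. One detail worth noting: this hunk divides the decoded pixels by `vae.config.scaling_factor`, while the usual Stable Diffusion convention divides the latents before calling `decode` (the decoder is nonlinear, so the two are not equivalent). A minimal sketch of that conventional decode path, assuming `latents` holds the denoised latent batch from the sampling loop; the helper name and the [-1, 1] to [0, 1] rescaling are illustrative assumptions, not code from this repository:

import torch
import torchvision.transforms as T

@torch.no_grad()
def decode_latents(vae, latents):
    # Conventional order: un-scale the latents first, then decode.
    images = vae.decode(latents / vae.config.scaling_factor).sample
    # The VAE decodes to roughly [-1, 1]; map to [0, 1] for ToPILImage.
    images = (images / 2 + 0.5).clamp(0, 1)
    tensor_to_pil = T.ToPILImage()
    return [tensor_to_pil(image.cpu()) for image in images]

Running the decode under `torch.no_grad()` matters here because decoding a full batch back to sample-sized images is the most memory-hungry step of inference.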
train_model.py

@@ -124,6 +124,9 @@ def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timeste
     vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
     vae = vae.to(dtype=torch.float32, device='cuda')
 
+    vae.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+
     optimizer = torch.optim.AdamW(unet.parameters(), lr=start_learning_rate)
     lr_scheduler = get_cosine_schedule_with_warmup(
         optimizer=optimizer,
@@ -149,26 +152,24 @@ def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timeste
     embeddings = create_embeddings(batch, model)
     clean_images = torch.reshape(clean_images, (batch['image'].size(0), SAMPLE_NUM_CHANNELS, SAMPLE_SIZE, SAMPLE_SIZE)).\
         to(device='cuda')
+
+    # use the vae to get the latent images
+    latent_images = vae.encode(clean_images).latent_dist.sample()
+    latent_images = latent_images * vae.config.scaling_factor
 
-    noise = torch.
+    noise = torch.randn_like(latent_images)
     timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_size, )).to(device='cuda')
 
     #timesteps = timesteps.to(dtype=torch.int, device='cuda')
-    noisy_images = scheduler.add_noise(
+    noisy_images = scheduler.add_noise(latent_images, noise, timesteps)
 
     batch_embeddings = embeddings
     batch_embeddings = batch_embeddings.to('cuda')
 
-    # use the vae to get the latent images
-    latent_images = vae.encode(noisy_images).latent_dist.sample()
-
     optimizer.zero_grad()
-    unet_results = unet(
-
-    # get back the upscale result
-    noise_pred = vae.decode(unet_results).sample
+    unet_results = unet(noisy_images, timesteps, batch_embeddings).sample
 
-    loss = loss_fn(
+    loss = loss_fn(unet_results, noise)
     loss.backward()
     optimizer.step()
     lr_scheduler.step()
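Taken together, the training loop is now a standard latent-diffusion step: encode the clean images once with the frozen VAE, scale by `vae.config.scaling_factor`, add scheduler noise in latent space, and train the UNet to predict that noise; the decode is gone from the loop entirely. A self-contained sketch of the same step, assuming an epsilon-prediction `UNet2DConditionModel`; the model sizes, shapes, and hyperparameters below are illustrative placeholders, not this repository's configuration:

import torch
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler

device = 'cuda'
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True).to(device)
vae.requires_grad_(False)  # frozen: only the UNet receives gradients

# Placeholder UNet config; 4 channels matches the AutoencoderKL latent space.
unet = UNet2DConditionModel(sample_size=32, in_channels=4, out_channels=4,
                            cross_attention_dim=768).to(device)
scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-4)
loss_fn = torch.nn.MSELoss()

def train_step(clean_images, embeddings):
    # clean_images: (B, 3, H, W) in [-1, 1]; embeddings: (B, seq_len, 768)
    with torch.no_grad():
        latents = vae.encode(clean_images).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

    noise = torch.randn_like(latents)
    timesteps = torch.randint(0, scheduler.config.num_train_timesteps,
                              (latents.shape[0],), device=device)
    noisy_latents = scheduler.add_noise(latents, noise, timesteps)

    optimizer.zero_grad()
    # The UNet predicts the added noise (epsilon objective), not pixels,
    # so the loss compares two latent-shaped tensors.
    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=embeddings).sample
    loss = loss_fn(noise_pred, noise)
    loss.backward()
    optimizer.step()
    return loss.item()

This is where the memory win comes from: the old loop decoded the UNet output back to pixels inside every step (see the removed `vae.decode(unet_results)` line), paying the decoder's forward and backward cost each iteration, whereas the new loop keeps both the prediction and the target in latent space.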