frutiemax committed on
Commit
3ab0859
·
1 Parent(s): a60f6bb

Don't decode in the training phase

Browse files
Files changed (2) hide show
  1. rct_diffusion_pipeline.py +5 -4
  2. train_model.py +11 -10
rct_diffusion_pipeline.py CHANGED
@@ -278,10 +278,11 @@ class RCTDiffusionPipeline(DiffusionPipeline):
278
  images = torch.Tensor(size=(batch_size, 3, self.sample_size, self.sample_size)).to('cuda')
279
  images = noise_batches[:, :3]
280
 
281
- #with torch.no_grad():
282
- #image = noise_batches
283
- #result = self.vae.decode(image).sample
284
- #images = result
 
285
 
286
  # convert those tensors to PIL images
287
  tensor_to_pil = T.ToPILImage()
 
278
  images = torch.Tensor(size=(batch_size, 3, self.sample_size, self.sample_size)).to('cuda')
279
  images = noise_batches[:, :3]
280
 
281
+ with torch.no_grad():
282
+ image = noise_batches
283
+ result = self.vae.decode(image).sample
284
+ images = result
285
+ images = images / self.vae.config.scaling_factor
286
 
287
  # convert those tensors to PIL images
288
  tensor_to_pil = T.ToPILImage()
train_model.py CHANGED
@@ -124,6 +124,9 @@ def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timeste
124
  vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
125
  vae = vae.to(dtype=torch.float32, device='cuda')
126
 
 
 
 
127
  optimizer = torch.optim.AdamW(unet.parameters(), lr=start_learning_rate)
128
  lr_scheduler = get_cosine_schedule_with_warmup(
129
  optimizer=optimizer,
@@ -149,26 +152,24 @@ def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timeste
149
  embeddings = create_embeddings(batch, model)
150
  clean_images = torch.reshape(clean_images, (batch['image'].size(0), SAMPLE_NUM_CHANNELS, SAMPLE_SIZE, SAMPLE_SIZE)).\
151
  to(device='cuda')
 
 
 
 
152
 
153
- noise = torch.randn(clean_images.shape, dtype=torch.float32, device='cuda')
154
  timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_size, )).to(device='cuda')
155
 
156
  #timesteps = timesteps.to(dtype=torch.int, device='cuda')
157
- noisy_images = scheduler.add_noise(clean_images, noise, timesteps)
158
 
159
  batch_embeddings = embeddings
160
  batch_embeddings = batch_embeddings.to('cuda')
161
 
162
- # use the vae to get the latent images
163
- latent_images = vae.encode(noisy_images).latent_dist.sample()
164
-
165
  optimizer.zero_grad()
166
- unet_results = unet(latent_images, timesteps, batch_embeddings).sample
167
-
168
- # get back the upscale result
169
- noise_pred = vae.decode(unet_results).sample
170
 
171
- loss = loss_fn(noise_pred, noise)
172
  loss.backward()
173
  optimizer.step()
174
  lr_scheduler.step()
 
124
  vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
125
  vae = vae.to(dtype=torch.float32, device='cuda')
126
 
127
+ vae.requires_grad_(False)
128
+ text_encoder.requires_grad_(False)
129
+
130
  optimizer = torch.optim.AdamW(unet.parameters(), lr=start_learning_rate)
131
  lr_scheduler = get_cosine_schedule_with_warmup(
132
  optimizer=optimizer,
 
152
  embeddings = create_embeddings(batch, model)
153
  clean_images = torch.reshape(clean_images, (batch['image'].size(0), SAMPLE_NUM_CHANNELS, SAMPLE_SIZE, SAMPLE_SIZE)).\
154
  to(device='cuda')
155
+
156
+ # use the vae to get the latent images
157
+ latent_images = vae.encode(clean_images).latent_dist.sample()
158
+ latent_images = latent_images * vae.config.scaling_factor
159
 
160
+ noise = torch.randn_like(latent_images)
161
  timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_size, )).to(device='cuda')
162
 
163
  #timesteps = timesteps.to(dtype=torch.int, device='cuda')
164
+ noisy_images = scheduler.add_noise(latent_images, noise, timesteps)
165
 
166
  batch_embeddings = embeddings
167
  batch_embeddings = batch_embeddings.to('cuda')
168
 
 
 
 
169
  optimizer.zero_grad()
170
+ unet_results = unet(noisy_images, timesteps, batch_embeddings).sample
 
 
 
171
 
172
+ loss = loss_fn(unet_results, noise)
173
  loss.backward()
174
  optimizer.step()
175
  lr_scheduler.step()