Chaerin5 commited on
Commit
c5c1afa
·
1 Parent(s): 50fb683

fix vae nan bug

Browse files
Files changed (1) hide show
  1. vqvae.py +7 -0
vqvae.py CHANGED
@@ -56,9 +56,13 @@ class Autoencoder(nn.Module):
56
  :param img: is the image tensor with shape `[batch_size, img_channels, img_height, img_width]`
57
  """
58
  # Get embeddings with shape `[batch_size, z_channels * 2, z_height, z_height]`
 
 
59
  z = self.encoder(img)
 
60
  # Get the moments in the quantized embedding space
61
  moments = self.quant_conv(z)
 
62
  # Return the distribution
63
  return GaussianDistribution(moments)
64
 
@@ -284,6 +288,7 @@ class GaussianDistribution:
284
  `[batch_size, z_channels * 2, z_height, z_height]`
285
  """
286
  # Split mean and log of variance
 
287
  self.mean, log_var = torch.chunk(parameters, 2, dim=1)
288
  # Clamp the log of variances
289
  self.log_var = torch.clamp(log_var, -30.0, 20.0)
@@ -293,6 +298,8 @@ class GaussianDistribution:
293
 
294
  def sample(self):
295
  # Sample from the distribution
 
 
296
  return self.mean + self.std * torch.randn_like(self.std)
297
 
298
  def kl(self):
 
56
  :param img: is the image tensor with shape `[batch_size, img_channels, img_height, img_width]`
57
  """
58
  # Get embeddings with shape `[batch_size, z_channels * 2, z_height, z_height]`
59
+ print(f"encoder parameters max: {max([p.max() for p in self.encoder.parameters()])}")
60
+ print(f"encoder parameters min: {min([p.min() for p in self.encoder.parameters()])}")
61
  z = self.encoder(img)
62
+ print(f"z.max(): {z.max()}, z.min(): {z.min()}")
63
  # Get the moments in the quantized embedding space
64
  moments = self.quant_conv(z)
65
+ print(f"moments.max(): {moments.max()}, moments.min(): {moments.min()}")
66
  # Return the distribution
67
  return GaussianDistribution(moments)
68
 
 
288
  `[batch_size, z_channels * 2, z_height, z_height]`
289
  """
290
  # Split mean and log of variance
291
+ print(f"parameters.max(): {parameters.max()}, parameters.min(): {parameters.min()}")
292
  self.mean, log_var = torch.chunk(parameters, 2, dim=1)
293
  # Clamp the log of variances
294
  self.log_var = torch.clamp(log_var, -30.0, 20.0)
 
298
 
299
  def sample(self):
300
  # Sample from the distribution
301
+ print(f"self.mean.max(): {self.mean.max()}, self.mean.min(): {self.mean.min()}")
302
+ print(f"self.std.max(): {self.std.max()}, self.std.min(): {self.std.min()}")
303
  return self.mean + self.std * torch.randn_like(self.std)
304
 
305
  def kl(self):