Spaces:
Running
on
Zero
Running
on
Zero
fix vae nan bug
Browse files
vqvae.py
CHANGED
@@ -56,9 +56,13 @@ class Autoencoder(nn.Module):
|
|
56 |
:param img: is the image tensor with shape `[batch_size, img_channels, img_height, img_width]`
|
57 |
"""
|
58 |
# Get embeddings with shape `[batch_size, z_channels * 2, z_height, z_height]`
|
|
|
|
|
59 |
z = self.encoder(img)
|
|
|
60 |
# Get the moments in the quantized embedding space
|
61 |
moments = self.quant_conv(z)
|
|
|
62 |
# Return the distribution
|
63 |
return GaussianDistribution(moments)
|
64 |
|
@@ -284,6 +288,7 @@ class GaussianDistribution:
|
|
284 |
`[batch_size, z_channels * 2, z_height, z_height]`
|
285 |
"""
|
286 |
# Split mean and log of variance
|
|
|
287 |
self.mean, log_var = torch.chunk(parameters, 2, dim=1)
|
288 |
# Clamp the log of variances
|
289 |
self.log_var = torch.clamp(log_var, -30.0, 20.0)
|
@@ -293,6 +298,8 @@ class GaussianDistribution:
|
|
293 |
|
294 |
def sample(self):
|
295 |
# Sample from the distribution
|
|
|
|
|
296 |
return self.mean + self.std * torch.randn_like(self.std)
|
297 |
|
298 |
def kl(self):
|
|
|
56 |
:param img: is the image tensor with shape `[batch_size, img_channels, img_height, img_width]`
|
57 |
"""
|
58 |
# Get embeddings with shape `[batch_size, z_channels * 2, z_height, z_height]`
|
59 |
+
print(f"encoder parameters max: {max([p.max() for p in self.encoder.parameters()])}")
|
60 |
+
print(f"encoder parameters min: {min([p.min() for p in self.encoder.parameters()])}")
|
61 |
z = self.encoder(img)
|
62 |
+
print(f"z.max(): {z.max()}, z.min(): {z.min()}")
|
63 |
# Get the moments in the quantized embedding space
|
64 |
moments = self.quant_conv(z)
|
65 |
+
print(f"moments.max(): {moments.max()}, moments.min(): {moments.min()}")
|
66 |
# Return the distribution
|
67 |
return GaussianDistribution(moments)
|
68 |
|
|
|
288 |
`[batch_size, z_channels * 2, z_height, z_height]`
|
289 |
"""
|
290 |
# Split mean and log of variance
|
291 |
+
print(f"parameters.max(): {parameters.max()}, parameters.min(): {parameters.min()}")
|
292 |
self.mean, log_var = torch.chunk(parameters, 2, dim=1)
|
293 |
# Clamp the log of variances
|
294 |
self.log_var = torch.clamp(log_var, -30.0, 20.0)
|
|
|
298 |
|
299 |
def sample(self):
|
300 |
# Sample from the distribution
|
301 |
+
print(f"self.mean.max(): {self.mean.max()}, self.mean.min(): {self.mean.min()}")
|
302 |
+
print(f"self.std.max(): {self.std.max()}, self.std.min(): {self.std.min()}")
|
303 |
return self.mean + self.std * torch.randn_like(self.std)
|
304 |
|
305 |
def kl(self):
|