Spaces:
Running
on
T4
Running
on
T4
Update Modules/ControllabilityGAN/GAN.py
Browse files
Modules/ControllabilityGAN/GAN.py
CHANGED
@@ -5,7 +5,7 @@ from Modules.ControllabilityGAN.wgan.init_wgan import create_wgan
|
|
5 |
|
6 |
class GanWrapper:
|
7 |
|
8 |
-
def __init__(self, path_wgan, device):
|
9 |
self.device = device
|
10 |
self.path_wgan = path_wgan
|
11 |
|
@@ -20,15 +20,18 @@ class GanWrapper:
|
|
20 |
|
21 |
self.z_list = list()
|
22 |
|
23 |
-
|
24 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
self.z = self.z_list[0]
|
26 |
|
27 |
def set_latent(self, seed):
|
28 |
-
self.z = self.
|
29 |
-
|
30 |
-
def reset_default_latent(self):
|
31 |
-
self.z = self.wgan.G.sample_latent(1, self.wgan.G.z_dim, temperature=0.8)
|
32 |
|
33 |
def load_model(self, path):
|
34 |
gan_checkpoint = torch.load(path, map_location="cpu")
|
@@ -53,7 +56,7 @@ class GanWrapper:
|
|
53 |
self.mean = gan_checkpoint["dataset_mean"]
|
54 |
self.std = gan_checkpoint["dataset_std"]
|
55 |
|
56 |
-
def compute_controllability(self, n_samples=
|
57 |
_, intermediate, z = self.wgan.sample_generator(num_samples=n_samples, nograd=True, return_intermediate=True)
|
58 |
intermediate = intermediate.cpu()
|
59 |
z = z.cpu()
|
|
|
5 |
|
6 |
class GanWrapper:
|
7 |
|
8 |
+
def __init__(self, path_wgan, device, num_cached_voices=10):
|
9 |
self.device = device
|
10 |
self.path_wgan = path_wgan
|
11 |
|
|
|
20 |
|
21 |
self.z_list = list()
|
22 |
|
23 |
+
while len(self.z_list) < num_cached_voices + 2:
|
24 |
+
z = self.wgan.G.sample_latent(1, self.wgan.G.z_dim, temperature=0.8)
|
25 |
+
sims = [-1.0]
|
26 |
+
for other_z in self.z_list:
|
27 |
+
sims.append(torch.nn.functional.cosine_similarity(z, other_z))
|
28 |
+
print(max(sims), len(self.z_list))
|
29 |
+
if max(sims) < 0.25:
|
30 |
+
self.z_list.append(z)
|
31 |
self.z = self.z_list[0]
|
32 |
|
33 |
def set_latent(self, seed):
|
34 |
+
self.z = self.z_list[seed]
|
|
|
|
|
|
|
35 |
|
36 |
def load_model(self, path):
|
37 |
gan_checkpoint = torch.load(path, map_location="cpu")
|
|
|
56 |
self.mean = gan_checkpoint["dataset_mean"]
|
57 |
self.std = gan_checkpoint["dataset_std"]
|
58 |
|
59 |
+
def compute_controllability(self, n_samples=200000):
|
60 |
_, intermediate, z = self.wgan.sample_generator(num_samples=n_samples, nograd=True, return_intermediate=True)
|
61 |
intermediate = intermediate.cpu()
|
62 |
z = z.cpu()
|