hysts HF staff committed on
Commit
0264baa
·
1 Parent(s): a84e7ee
Files changed (1) hide show
  1. app.py +18 -16
app.py CHANGED
@@ -2,7 +2,6 @@
2
 
3
  from __future__ import annotations
4
 
5
- import functools
6
  import os
7
  import random
8
  import shlex
@@ -58,12 +57,16 @@ def load_model(device: torch.device) -> nn.Module:
58
  return model
59
 
60
 
61
- def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
62
- return torch.from_numpy(np.random.RandomState(seed).randn(1, z_dim)).to(device).float()
 
 
 
 
63
 
64
 
65
  @torch.inference_mode()
66
- def generate_image(model: nn.Module, z: torch.Tensor, truncation_psi: float, randomize_noise: bool) -> np.ndarray:
67
  out, _ = model([z], truncation=truncation_psi, truncation_latent=model.latent_avg, randomize_noise=randomize_noise)
68
  out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
69
  return out[0].cpu().numpy()
@@ -77,14 +80,15 @@ def generate_interpolated_images(
77
  psi0: float,
78
  psi1: float,
79
  randomize_noise: bool,
80
- model: nn.Module,
81
- device: torch.device,
82
  ) -> list[np.ndarray]:
83
  seed0 = int(np.clip(seed0, 0, MAX_SEED))
84
  seed1 = int(np.clip(seed1, 0, MAX_SEED))
85
 
86
- z0 = generate_z(model.style_dim, seed0, device)
87
- z1 = generate_z(model.style_dim, seed1, device)
 
 
 
88
  vec = z1 - z0
89
  dvec = vec / (num_intermediate + 1)
90
  zs = [z0 + dvec * i for i in range(num_intermediate + 2)]
@@ -92,15 +96,11 @@ def generate_interpolated_images(
92
  psis = [psi0 + dpsi * i for i in range(num_intermediate + 2)]
93
  res = []
94
  for z, psi in zip(zs, psis):
95
- out = generate_image(model, z, psi, randomize_noise)
96
  res.append(out)
97
  return res
98
 
99
 
100
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
101
- model = load_model(device)
102
- fn = functools.partial(generate_interpolated_images, model=model, device=device)
103
-
104
  examples = [
105
  [29703, 55376, 3, 0.7, 0.7, False],
106
  [34141, 36864, 5, 0.7, 0.7, False],
@@ -141,13 +141,15 @@ with gr.Blocks(css="style.css") as demo:
141
  examples=examples,
142
  inputs=inputs,
143
  outputs=result,
144
- fn=fn,
145
  cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
146
  )
147
  run_button.click(
148
- fn=fn,
149
  inputs=inputs,
150
  outputs=result,
151
  api_name="run",
152
  )
153
- demo.queue(max_size=10).launch()
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
5
  import os
6
  import random
7
  import shlex
 
57
  return model
58
 
59
 
60
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
61
+ model = load_model(device)
62
+
63
+
64
def generate_z(z_dim: int, seed: int) -> torch.Tensor:
    """Draw one latent vector of shape (1, z_dim) from a seeded standard Gaussian.

    Uses a dedicated ``np.random.RandomState`` so the same seed always yields
    the same latent, independent of the global NumPy RNG state.
    """
    rng = np.random.RandomState(seed)
    latent = rng.randn(1, z_dim)
    return torch.from_numpy(latent).float()
66
 
67
 
68
@torch.inference_mode()
def generate_image(z: torch.Tensor, truncation_psi: float, randomize_noise: bool) -> np.ndarray:
    """Run the module-level generator on latent ``z`` and return an HWC uint8 image.

    Args:
        z: Latent vector passed to the generator (presumably shape (1, style_dim)
           on the model's device — confirm against generate_z/caller).
        truncation_psi: Truncation factor forwarded to the model, applied around
            ``model.latent_avg``.
        randomize_noise: Forwarded to the model's noise inputs.

    Returns:
        The first generated image as a ``uint8`` NumPy array in HWC layout.
    """
    raw, _ = model(
        [z],
        truncation=truncation_psi,
        truncation_latent=model.latent_avg,
        randomize_noise=randomize_noise,
    )
    # NCHW -> NHWC, then affine-map the generator output (presumably ~[-1, 1])
    # onto [0, 255] and clamp into byte range.
    nhwc = raw.permute(0, 2, 3, 1)
    pixels = (nhwc * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return pixels[0].cpu().numpy()
 
80
  psi0: float,
81
  psi1: float,
82
  randomize_noise: bool,
 
 
83
  ) -> list[np.ndarray]:
84
  seed0 = int(np.clip(seed0, 0, MAX_SEED))
85
  seed1 = int(np.clip(seed1, 0, MAX_SEED))
86
 
87
+ z0 = generate_z(model.style_dim, seed0)
88
+ z1 = generate_z(model.style_dim, seed1)
89
+ z0 = z0.to(device)
90
+ z1 = z1.to(device)
91
+
92
  vec = z1 - z0
93
  dvec = vec / (num_intermediate + 1)
94
  zs = [z0 + dvec * i for i in range(num_intermediate + 2)]
 
96
  psis = [psi0 + dpsi * i for i in range(num_intermediate + 2)]
97
  res = []
98
  for z, psi in zip(zs, psis):
99
+ out = generate_image(z, psi, randomize_noise)
100
  res.append(out)
101
  return res
102
 
103
 
 
 
 
 
104
  examples = [
105
  [29703, 55376, 3, 0.7, 0.7, False],
106
  [34141, 36864, 5, 0.7, 0.7, False],
 
141
  examples=examples,
142
  inputs=inputs,
143
  outputs=result,
144
+ fn=generate_interpolated_images,
145
  cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
146
  )
147
  run_button.click(
148
+ fn=generate_interpolated_images,
149
  inputs=inputs,
150
  outputs=result,
151
  api_name="run",
152
  )
153
+
154
+ if __name__ == "__main__":
155
+ demo.queue(max_size=10).launch()