hysts HF staff commited on
Commit
69c1172
β€’
1 Parent(s): b2db920
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +26 -22
  3. requirements.txt +3 -3
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🌍
4
  colorFrom: purple
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 3.36.1
8
  app_file: app.py
9
  pinned: false
10
  suggested_hardware: t4-small
 
4
  colorFrom: purple
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 4.36.0
8
  app_file: app.py
9
  pinned: false
10
  suggested_hardware: t4-small
app.py CHANGED
@@ -2,7 +2,6 @@
2
 
3
  from __future__ import annotations
4
 
5
- import functools
6
  import pickle
7
  import sys
8
 
@@ -18,22 +17,6 @@ TITLE = "StyleGAN-Human"
18
  DESCRIPTION = "https://github.com/stylegan-human/StyleGAN-Human"
19
 
20
 
21
- def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
22
- return torch.from_numpy(np.random.RandomState(seed).randn(1, z_dim)).to(device).float()
23
-
24
-
25
- @torch.inference_mode()
26
- def generate_image(seed: int, truncation_psi: float, model: nn.Module, device: torch.device) -> np.ndarray:
27
- seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
28
-
29
- z = generate_z(model.z_dim, seed, device)
30
- label = torch.zeros([1, model.c_dim], device=device)
31
-
32
- out = model(z, label, truncation_psi=truncation_psi, force_fp32=True)
33
- out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
34
- return out[0].cpu().numpy()
35
-
36
-
37
  def load_model(file_name: str, device: torch.device) -> nn.Module:
38
  path = hf_hub_download("public-data/StyleGAN-Human", f"models/{file_name}")
39
  with open(path, "rb") as f:
@@ -49,15 +32,36 @@ def load_model(file_name: str, device: torch.device) -> nn.Module:
49
 
50
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
51
  model = load_model("stylegan_human_v2_1024.pkl", device)
52
- fn = functools.partial(generate_image, model=model, device=device)
53
 
54
- gr.Interface(
55
- fn=fn,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  inputs=[
57
  gr.Slider(label="Seed", minimum=0, maximum=100000, step=1, value=0),
58
  gr.Slider(label="Truncation psi", minimum=0, maximum=2, step=0.05, value=0.7),
59
  ],
60
- outputs=gr.Image(label="Output", type="numpy"),
61
  title=TITLE,
62
  description=DESCRIPTION,
63
- ).queue(max_size=10).launch()
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
5
  import pickle
6
  import sys
7
 
 
17
  DESCRIPTION = "https://github.com/stylegan-human/StyleGAN-Human"
18
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def load_model(file_name: str, device: torch.device) -> nn.Module:
21
  path = hf_hub_download("public-data/StyleGAN-Human", f"models/{file_name}")
22
  with open(path, "rb") as f:
 
32
 
33
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
34
  model = load_model("stylegan_human_v2_1024.pkl", device)
 
35
 
36
+
37
+ def generate_z(z_dim: int, seed: int) -> torch.Tensor:
38
+ return torch.from_numpy(np.random.RandomState(seed).randn(1, z_dim)).float()
39
+
40
+
41
+ @torch.inference_mode()
42
+ def generate_image(seed: int, truncation_psi: float) -> np.ndarray:
43
+ seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
44
+
45
+ z = generate_z(model.z_dim, seed)
46
+ z = z.to(device)
47
+ label = torch.zeros([1, model.c_dim], device=device)
48
+
49
+ out = model(z, label, truncation_psi=truncation_psi, force_fp32=True)
50
+ out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
51
+ return out[0].cpu().numpy()
52
+
53
+
54
+ demo = gr.Interface(
55
+ fn=generate_image,
56
  inputs=[
57
  gr.Slider(label="Seed", minimum=0, maximum=100000, step=1, value=0),
58
  gr.Slider(label="Truncation psi", minimum=0, maximum=2, step=0.05, value=0.7),
59
  ],
60
+ outputs=gr.Image(label="Output"),
61
  title=TITLE,
62
  description=DESCRIPTION,
63
+ )
64
+
65
+
66
+ if __name__ == "__main__":
67
+ demo.queue(max_size=10).launch()
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
- numpy==1.23.5
2
- Pillow==10.0.0
3
- scipy==1.10.1
4
  torch==2.0.1
5
  torchvision==0.15.2
 
1
+ numpy==1.26.4
2
+ Pillow==10.3.0
3
+ scipy==1.13.1
4
  torch==2.0.1
5
  torchvision==0.15.2