Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -9,10 +9,10 @@ device = "cpu"
|
|
9 |
@torch.inference_mode()
|
10 |
def inference_gan():
|
11 |
generator = torch.jit.load("mnist-G-torchscript.pt").to(device)
|
12 |
-
x = torch.randn(30, 256, device=
|
13 |
y = generator(x)
|
14 |
y = y.view(-1, 1, 28, 28) # reshape y to have 1 channel
|
15 |
-
grid = make_grid(y.
|
16 |
img = T.functional.to_pil_image(grid)
|
17 |
return img
|
18 |
|
@@ -22,10 +22,10 @@ def inference_dcgan():
|
|
22 |
def denorm(img_tensors):
|
23 |
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
|
24 |
return img_tensors * stats[1][0] + stats[0][0]
|
25 |
-
x = torch.randn(64, 128, 1, 1, device=
|
26 |
y = generator(x)
|
27 |
y = y.view(-1, 3, 64, 64) # reshape y to have 3 channels
|
28 |
-
grid = make_grid(denorm(y.
|
29 |
img = T.functional.to_pil_image(grid)
|
30 |
return img
|
31 |
def inference_both():
|
|
|
9 |
@torch.inference_mode()
|
10 |
def inference_gan():
|
11 |
generator = torch.jit.load("mnist-G-torchscript.pt").to(device)
|
12 |
+
x = torch.randn(30, 256, device=device)
|
13 |
y = generator(x)
|
14 |
y = y.view(-1, 1, 28, 28) # reshape y to have 1 channel
|
15 |
+
grid = make_grid(y.detach(), nrow=8)
|
16 |
img = T.functional.to_pil_image(grid)
|
17 |
return img
|
18 |
|
|
|
22 |
def denorm(img_tensors):
|
23 |
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
|
24 |
return img_tensors * stats[1][0] + stats[0][0]
|
25 |
+
x = torch.randn(64, 128, 1, 1, device=device)
|
26 |
y = generator(x)
|
27 |
y = y.view(-1, 3, 64, 64) # reshape y to have 3 channels
|
28 |
+
grid = make_grid(denorm(y.detach()), nrow=8)
|
29 |
img = T.functional.to_pil_image(grid)
|
30 |
return img
|
31 |
def inference_both():
|