minhalvp commited on
Commit
abf7df3
·
1 Parent(s): bc69298

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -9,10 +9,10 @@ device = "cpu"
9
  @torch.inference_mode()
10
  def inference_gan():
11
  generator = torch.jit.load("mnist-G-torchscript.pt").to(device)
12
- x = torch.randn(30, 256, device='cuda')
13
  y = generator(x)
14
  y = y.view(-1, 1, 28, 28) # reshape y to have 1 channel
15
- grid = make_grid(y.cpu().detach(), nrow=8)
16
  img = T.functional.to_pil_image(grid)
17
  return img
18
 
@@ -22,10 +22,10 @@ def inference_dcgan():
22
  def denorm(img_tensors):
23
  stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
24
  return img_tensors * stats[1][0] + stats[0][0]
25
- x = torch.randn(64, 128, 1, 1, device='cuda')
26
  y = generator(x)
27
  y = y.view(-1, 3, 64, 64) # reshape y to have 3 channels
28
- grid = make_grid(denorm(y.cpu().detach()), nrow=8)
29
  img = T.functional.to_pil_image(grid)
30
  return img
31
  def inference_both():
 
9
  @torch.inference_mode()
10
  def inference_gan():
11
  generator = torch.jit.load("mnist-G-torchscript.pt").to(device)
12
+ x = torch.randn(30, 256, device=device)
13
  y = generator(x)
14
  y = y.view(-1, 1, 28, 28) # reshape y to have 1 channel
15
+ grid = make_grid(y.detach(), nrow=8)
16
  img = T.functional.to_pil_image(grid)
17
  return img
18
 
 
22
  def denorm(img_tensors):
23
  stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
24
  return img_tensors * stats[1][0] + stats[0][0]
25
+ x = torch.randn(64, 128, 1, 1, device=device)
26
  y = generator(x)
27
  y = y.view(-1, 3, 64, 64) # reshape y to have 3 channels
28
+ grid = make_grid(denorm(y.detach()), nrow=8)
29
  img = T.functional.to_pil_image(grid)
30
  return img
31
  def inference_both():