File size: 1,296 Bytes
e161624
 
 
 
 
 
bc69298
e161624
 
 
ccd56bc
e161624
 
 
 
 
 
 
 
 
ccd56bc
e161624
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import streamlit as st
import torch.nn as nn
import torchvision.transforms as T
from torchvision.utils import make_grid
import torch

# Device that the TorchScript generators in this script are placed on.
device = "cpu"

@torch.inference_mode()
def inference_gan():
    """Generate a grid of samples from the MNIST generator.

    Loads the TorchScript generator stored in ``mnist-G-torchscript.pt``,
    feeds it 30 random latent vectors of size 256, reshapes the output to
    single-channel 28x28 images and tiles them 8 per row.

    Returns:
        The tiled samples converted to a PIL image by
        ``torchvision.transforms.functional.to_pil_image``.
    """
    # map_location remaps the saved tensors onto ``device`` at load time, so
    # a checkpoint exported from a GPU still loads on a CPU-only host.
    generator = torch.jit.load("mnist-G-torchscript.pt", map_location=device)
    # Sample the noise on the same device as the model. Hard-coding 'cuda'
    # here raised on machines without a GPU and mismatched the CPU model.
    x = torch.randn(30, 256, device=device)
    y = generator(x)
    y = y.view(-1, 1, 28, 28)  # reshape y to have 1 channel
    # No .detach() needed: inference_mode never records autograd history.
    grid = make_grid(y.cpu(), nrow=8)
    img = T.functional.to_pil_image(grid)
    return img

@torch.inference_mode()
def inference_dcgan():
    """Generate a grid of samples from the anime-face DCGAN generator.

    Loads the TorchScript generator stored in
    ``animefacedataset-G2-torchscript.pt``, feeds it 64 random latent tensors
    of shape (128, 1, 1), reshapes the output to 3-channel 64x64 images,
    undoes the (mean=0.5, std=0.5) normalisation and tiles them 8 per row.

    Returns:
        The tiled samples converted to a PIL image by
        ``torchvision.transforms.functional.to_pil_image``.
    """
    # map_location remaps the saved tensors onto ``device`` at load time, so
    # a checkpoint exported from a GPU still loads on a CPU-only host.
    generator = torch.jit.load(
        "animefacedataset-G2-torchscript.pt", map_location=device
    )
    # Sample the noise on the same device as the model. Hard-coding 'cuda'
    # here raised on machines without a GPU and mismatched the CPU model.
    x = torch.randn(64, 128, 1, 1, device=device)
    y = generator(x)
    y = y.view(-1, 3, 64, 64)  # reshape y to have 3 channels
    # Inverse of Normalize(mean=0.5, std=0.5): x * std + mean.
    # No .detach() needed: inference_mode never records autograd history.
    denormalized = y.cpu() * 0.5 + 0.5
    grid = make_grid(denormalized, nrow=8)
    img = T.functional.to_pil_image(grid)
    return img
def inference_both():
    """Run the MNIST GAN and then the DCGAN generator once each.

    The generated images are discarded and nothing is returned; this is the
    function handed to the "Generate Images" button as its click callback.
    """
    for generate in (inference_gan, inference_dcgan):
        generate()

# Page heading.
st.markdown("# Image Generation with GANs and DCGANs")

# The click callback runs both generators but discards their output; the
# images displayed on the page come from the two calls below.
st.button("Generate Images", on_click=inference_both)

# DCGAN samples are rendered first, followed by the MNIST GAN samples.
dcgan_image = inference_dcgan()
st.image(dcgan_image, caption="", use_column_width=True)

gan_image = inference_gan()
st.image(gan_image, caption="", use_column_width=True)