File size: 4,759 Bytes
5a67fb4
 
 
 
 
 
 
 
 
 
 
 
 
d872920
5a67fb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import math

import numpy as np
import streamlit as st

import torch
import torch.nn.functional as F

import src.app.params as params
from src.models import ConditionalGenerator as InfoSCC_GAN
from src.models.big.BigGAN2 import Generator as BigGAN2Generator
from src.models import ConditionalDecoder as cVAE
from src.data import get_labels_train
from src.utils import sample_labels


# Shared settings taken from the app-wide parameter module.
device = params.device
size = params.size

# Layer count derived from the image resolution: log2(size) - 2.
n_layers = int(math.log2(size) - 2)

# Number of interpolation steps, i.e. images rendered per model.
bs = 12

# Column vector (bs, 1) of evenly spaced mixing weights going from 0 to 1.
lin_space = torch.linspace(0, 1, bs).view(bs, 1)

# One caption per interpolation step, spelling out the mix of the two labels.
captions = [
    f'label_a * {(1 - weight):.02f} + label_b * {weight:.02f}'
    for weight in lin_space.squeeze().numpy()
]


@st.cache(allow_output_mutation=True)
def load_model(model_type: str):
    """Build one of the supported generators and load its checkpoint.

    The checkpoint is always deserialized onto the CPU; the model is then
    switched to eval mode and moved to ``params.device``. The result is
    cached by Streamlit, so the checkpoint is only read when the cache
    misses.

    Args:
        model_type: One of ``'InfoSCC-GAN'``, ``'BigGAN'`` or ``'cVAE'``.

    Returns:
        The generator in eval mode on ``params.device``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    print(f'Loading model: {model_type}')
    cpu = torch.device('cpu')

    if model_type == 'InfoSCC-GAN':
        generator = InfoSCC_GAN(
            size=params.size,
            y_size=params.shape_label,
            z_size=params.noise_dim,
        )
        # This checkpoint is a dict; the weights to load sit under 'g_ema'.
        state_dict = torch.load(params.path_infoscc_gan, map_location=cpu)['g_ema']
    elif model_type == 'BigGAN':
        generator = BigGAN2Generator()
        state_dict = torch.load(params.path_biggan, map_location=cpu)
    elif model_type == 'cVAE':
        generator = cVAE()
        state_dict = torch.load(params.path_cvae, map_location=cpu)
    else:
        raise ValueError('Unsupported model')

    generator.load_state_dict(state_dict)
    generator = generator.eval()
    return generator.to(device=params.device)


@st.cache
def get_labels() -> torch.Tensor:
    """Load the training labels from ``params.path_labels``.

    Cached by Streamlit so the label file is not re-read on every rerun.
    """
    return get_labels_train(params.path_labels)


def get_eps(n: int) -> torch.Tensor:
    """Draw ``n`` standard-normal latent vectors of width ``params.dim_z``.

    Args:
        n: Number of latent vectors to sample.

    Returns:
        Tensor of shape ``(n, params.dim_z)`` allocated on ``device``.
    """
    shape = (n, params.dim_z)
    return torch.randn(shape, device=device)


# Session-state keys; prefixed so they cannot collide with other pages.
_LABELS_KEY = 'interpolate_labels.label_interpolated'
_EPS_KEY = 'interpolate_labels.eps'
_EPS_INFOSCC_KEY = 'interpolate_labels.eps_infoscc'


def _sample_interpolated_labels(labels_train: torch.Tensor) -> torch.Tensor:
    """Sample two labels and blend them linearly over ``bs`` steps."""
    label_a = sample_labels(labels_train, n=1).repeat(bs, 1)
    label_b = sample_labels(labels_train, n=1).repeat(bs, 1)
    return (1 - lin_space) * label_a + lin_space * label_b


def _to_uint8_images(batch: torch.Tensor, low: float, high: float,
                     scale: float, offset: float) -> list:
    """Clamp a batch to ``[low, high]`` and convert it to HWC uint8 arrays.

    Clamping matters because bicubic upsampling (and the generators
    themselves) can overshoot the nominal range; casting such values to
    ``uint8`` would wrap around and show up as speckle artifacts.
    """
    batch = torch.clip(batch, low, high)
    return [
        (img.permute(1, 2, 0).numpy() * scale + offset).astype(np.uint8)
        for img in batch
    ]


def app():
    """Render the 'Interpolate Labels' Streamlit page.

    Two labels are sampled and linearly interpolated over ``bs`` steps; the
    same noise vector is reused for every step. BigGAN, InfoSCC-GAN and cVAE
    each generate one image per interpolated label, shown in three columns.

    Streamlit re-executes the script on every widget interaction, so the
    sampled labels and noise are kept in ``st.session_state``. 'Sample label'
    therefore changes only the labels and 'Change eps' only the noise.
    """
    st.title('Interpolate Labels')
    st.markdown('This app allows the generation of the images with the labels that are interpolated between two labels.')
    st.markdown('In each row there are images generated with the same interpolated label by one of the models')

    biggan = load_model('BigGAN')
    infoscc_gan = load_model('InfoSCC-GAN')
    cvae = load_model('cVAE')
    labels_train = get_labels()

    state = st.session_state

    # ==================== Labels ==============================================
    # The button is evaluated first so it is rendered on every run.
    if st.button('Sample label') or _LABELS_KEY not in state:
        state[_LABELS_KEY] = _sample_interpolated_labels(labels_train)
    # Keep the labels on the same device as the noise fed to the models.
    label_interpolated = state[_LABELS_KEY].to(device)
    # ==================== Labels ==============================================

    # ==================== Noise ==============================================
    st.subheader('Noise')
    st.markdown(r'Click on __Change eps__ button to change input $\varepsilon$ latent space')
    if (st.button('Change eps')
            or _EPS_KEY not in state
            or _EPS_INFOSCC_KEY not in state):
        # One noise vector shared by all interpolation steps, so only the
        # label varies along a column.
        state[_EPS_KEY] = get_eps(1).repeat(bs, 1)
        state[_EPS_INFOSCC_KEY] = infoscc_gan.sample_eps(1).repeat(bs, 1)
    eps = state[_EPS_KEY]
    eps_infoscc = state[_EPS_INFOSCC_KEY]

    # All-zero third input for InfoSCC-GAN, shape (bs, n_layers, n_basis).
    zs_torch = torch.zeros((bs, n_layers, params.n_basis),
                           dtype=torch.float32, device=device)
    # ==================== Noise ==============================================

    with torch.no_grad():
        # NOTE(review): squeeze(0) is a no-op for a leading dim of bs=12;
        # kept in case a model returns an extra leading singleton dim.
        imgs_biggan = biggan(eps, label_interpolated).squeeze(0).cpu()
        imgs_infoscc = infoscc_gan(label_interpolated, eps_infoscc, zs_torch).squeeze(0).cpu()
        imgs_cvae = cvae(eps, label_interpolated).squeeze(0).cpu()

    if params.upsample:
        target = (size * 4, size * 4)
        imgs_biggan = F.interpolate(imgs_biggan, target, mode='bicubic')
        imgs_infoscc = F.interpolate(imgs_infoscc, target, mode='bicubic')
        imgs_cvae = F.interpolate(imgs_cvae, target, mode='bicubic')

    # BigGAN output is treated as [0, 1]; the other two as [-1, 1].
    imgs_biggan = _to_uint8_images(imgs_biggan, 0, 1, 255, 0)
    imgs_infoscc = _to_uint8_images(imgs_infoscc, -1, 1, 127.5, 127.5)
    imgs_cvae = _to_uint8_images(imgs_cvae, -1, 1, 127.5, 127.5)

    c1, c2, c3 = st.columns(3)
    c1.header('BigGAN')
    c1.image(imgs_biggan, use_column_width=True, caption=captions)

    c2.header('InfoSCC-GAN')
    c2.image(imgs_infoscc, use_column_width=True, caption=captions)

    c3.header('cVAE')
    c3.image(imgs_cvae, use_column_width=True, caption=captions)