# SOAT / app.py
# Ahsen Khaliq
# Update app.py
# 871d6aa
import os
import subprocess
import sys

# Fetch the SOAT repository, which presumably provides the `model` and `util`
# modules star-imported below (they are resolved via the sys.path entry added
# here). Skip the clone when a previous run already created the directory:
# `git clone` into an existing non-empty directory fails. Use an argument list
# (no shell) and check=True so a failed clone surfaces immediately instead of as
# a confusing ImportError later.
if not os.path.isdir("SOAT"):
    subprocess.run(["git", "clone", "https://github.com/mchong6/SOAT.git"], check=True)
sys.path.append("SOAT")
# NOTE(review): in the code visible in this file, torchvision (as a module), nn,
# math, tqdm, pickle, deepcopy, k, save_image, PIL and transforms are never
# referenced; they may be removable once confirmed unused elsewhere.
import os  # NOTE(review): duplicate of the `import os` above; harmless
import torch
import torchvision
from torch import nn
import numpy as np
import torch.backends.cudnn as cudnn
# Let cuDNN benchmark and cache the fastest convolution algorithms; this only
# has an effect when running on a CUDA device.
cudnn.benchmark = True
import math
import matplotlib.pyplot as plt
import torch.nn.functional as F
# Star import from the cloned SOAT repo (put on sys.path above); presumably
# supplies `Generator` used below — TODO confirm and import it explicitly.
from model import *
from tqdm import tqdm as tqdm
import pickle
from copy import deepcopy
import warnings
warnings.filterwarnings("ignore", category=UserWarning) # get rid of interpolation warning
import kornia.filters as k
from torchvision.utils import save_image
# Star import from the SOAT repo; presumably supplies `load_model` and
# `index_layers` used below — TODO confirm and import them explicitly.
from util import *
# NOTE(review): on older SciPy releases a bare `import scipy` does not load the
# `scipy.ndimage` submodule; code that uses it should import it explicitly.
import scipy
import gradio as gr
import PIL
from torchvision import transforms
# Everything runs on the CPU; the trailing `#@param` looks like a leftover
# Colab form annotation and has no effect here.
device = 'cpu' #@param ['cuda', 'cpu']
# Truncation factor handed to generator.get_latent when sampling styles.
truncation = 0.7
# Shared StyleGAN generator used by every request. The positional arguments are
# presumably (image size, latent dim, mapping layers) — TODO confirm against
# SOAT's model.py; the 512 matches the latent width sampled in inferece.
generator = Generator(256, 512, 8, channel_multiplier=2)
generator = generator.eval().to(device)
def display_image(image, size=None, mode='nearest', unnorm=False, title=''):
    """Convert a generator output tensor into an ``(H, W, C)`` numpy image in [0, 1].

    Despite the name, nothing is displayed; the array is returned to the caller.

    Args:
        image: ``[C, H, W]`` or ``[B, C, H, W]`` tensor with values in ``[-1, 1]``
            (anything outside is clamped). Only the first item of a batch is used.
        size: if given, resize to ``size x size`` with ``F.interpolate`` first.
        mode: interpolation mode forwarded to ``F.interpolate``.
        unnorm: accepted for backward compatibility and ignored; the
            ``[-1, 1] -> [0, 1]`` mapping is always applied.
        title: accepted for backward compatibility and ignored.

    Returns:
        ``numpy.ndarray`` of shape ``(H, W, C)`` (``(size, size, C)`` when resized).
    """
    image = image.detach().cpu()
    # F.interpolate needs a batch dimension, so promote [C, H, W] to [1, C, H, W]
    # before resizing (resizing a 3-D tensor to a 2-D size raises).
    if image.dim() == 3:
        image = image.unsqueeze(0)
    # Compare both spatial dims, not only the width, so non-square inputs resize too.
    if size is not None and tuple(image.shape[-2:]) != (size, size):
        image = F.interpolate(image, size=(size, size), mode=mode)
    image = image[0]
    # Map [-1, 1] -> [0, 1] and move channels last (CHW -> HWC) for image consumers.
    return ((image.clamp(-1, 1) + 1) / 2).permute(1, 2, 0).numpy()
#mean_latentland = load_model(generator, 'landscape.pt')
#mean_latentface = load_model(generator, 'face.pt')
#mean_latentchurch = load_model(generator, 'church.pt')
# Memo of the checkpoint currently loaded into `generator` and the mean latent
# load_model returned for it, so repeated requests do not reload the weights.
_loaded_checkpoint = {'name': None, 'mean_latent': None}


def _get_mean_latent(ckpt):
    """Load `ckpt` into the shared generator once and return its mean latent."""
    # NOTE(review): assumes load_model's effect on `generator` persists between
    # calls and that its return value can be reused; the commented-out
    # module-level load_model calls above suggest loading once was the intent.
    if _loaded_checkpoint['name'] != ckpt:
        _loaded_checkpoint['mean_latent'] = load_model(generator, ckpt)
        _loaded_checkpoint['name'] = ckpt
    return _loaded_checkpoint['mean_latent']


def _smooth_latents(num_im, random_seed):
    """Sample `num_im` float32 latent vectors and smooth them along the image axis."""
    # Import the submodule explicitly: a bare `import scipy` does not load
    # scipy.ndimage on older SciPy versions.
    from scipy.ndimage import gaussian_filter
    random_state = np.random.RandomState(random_seed)
    z = random_state.randn(num_im, 512).astype(np.float32)
    # Blur only across neighbouring images (sigma 0.7 on axis 0, none on the
    # latent axis); mode='wrap' treats the sequence as circular.
    z = gaussian_filter(z, [.7, 0], mode='wrap')
    # Rescale to unit RMS, since averaging shrinks the latents' magnitude.
    z /= np.sqrt(np.mean(np.square(z)))
    return z


def _stitch_panorama(all_im):
    """Overlay consecutive 2-image blends ``[B, C, H, w]`` into one wide tensor.

    Each blend is assumed to hold two source images side by side, so
    consecutive blends overlap by half a blend width — TODO confirm against
    ``generator.merge_extension``. `pad` controls how much of each blend is used:
    interior blends contribute only their centre, with `pad` columns trimmed on
    each side.
    """
    b, c, h, w = all_im[0].shape
    pad = w // 4            # equals the previously hard-coded 512 // 4 when w == 512
    step = w - 2 * pad      # columns contributed by each interior blend
    first = w - pad         # columns taken from the first blend
    total = first + (len(all_im) - 1) * step + pad
    # Allocate on the blends' own device/dtype so stitching also works off-CPU.
    panorama_im = all_im[0].new_zeros((b, c, h, total))
    # add first image
    panorama_im[..., :first] = all_im[0][..., :first]
    coord = first
    for im in all_im[1:]:
        panorama_im[..., coord:coord + step] = im[..., pad:w - pad]
        coord += step
    # Tail: the last `pad` columns of the final blend.
    panorama_im[..., coord:] = all_im[-1][..., w - pad:]
    return panorama_im


def inferece(num, seed):
    """Generate a landscape panorama with SOAT and return it as a numpy image.

    `num` latent codes are sampled from `seed` and smoothed across the sequence.
    Neighbouring pairs are blended with ``generator.merge_extension``, and the
    blends are then stitched side by side. The misspelled name is kept because
    the Gradio interface refers to it.

    Args:
        num: number of images to stitch (a float is truncated to int); must be
            at least 2.
        seed: random seed; wrapped into NumPy's valid ``[0, 2**32)`` range.

    Returns:
        ``(H, W, C)`` numpy array in ``[0, 1]`` as produced by ``display_image``.

    Raises:
        ValueError: if fewer than 2 images are requested (there would be no pair
            to blend).
    """
    num_im = int(num)
    if num_im < 2:
        raise ValueError(f"Number of images must be at least 2, got {num_im}.")
    # RandomState only accepts seeds in [0, 2**32); wrapping leaves valid seeds unchanged.
    random_seed = int(seed) % (2 ** 32)
    mean_latent = _get_mean_latent('landscape.pt')

    with torch.no_grad():
        # latent smoothing
        z = torch.from_numpy(_smooth_latents(num_im, random_seed)).to(device)
        source = generator.get_latent(z, truncation=truncation, mean_latent=mean_latent)
        # merge images 2 at a time: blend i combines latents i and i+1
        all_im = [
            generator.merge_extension(index_layers(source, i), index_layers(source, i + 1))
            for i in range(num_im - 1)
        ]

    return display_image(_stitch_panorama(all_im))
title = "SOAT"
description = (
    "Gradio demo for SOAT Panorama Generation for landscapes. "
    "Generate a panorama using a pretrained StyleGAN by stitching intermediate activations. "
    "To use it, simply enter the number of images (at least 2) and a random seed. "
    "Read more at the links below."
)
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.01619' target='_blank'>StyleGAN of All Trades: Image Manipulation with Only Pretrained StyleGAN</a> | <a href='https://github.com/mchong6/SOAT' target='_blank'>Github Repo</a></p>"

# Build the web UI around inferece (two numeric inputs -> one numpy image) and serve it.
# NOTE(review): gr.inputs / gr.outputs / enable_queue look like the Gradio 2.x API
# (deprecated in later releases); make sure the Space pins a compatible gradio.
gr.Interface(
    inferece,
    [
        gr.inputs.Number(default=5, label="Number of Images"),
        gr.inputs.Number(default=90, label="Random Seed"),
    ],
    gr.outputs.Image(type="numpy", label="Output"),
    title=title,
    description=description,
    article=article,
    theme="huggingface",
    enable_queue=True,
).launch(debug=True)