|
import os
import subprocess
import sys

# Location of the SOAT sources this demo depends on.
_SOAT_REPO_URL = "https://github.com/mchong6/SOAT.git"
_SOAT_DIR = "SOAT"

# Clone only when the checkout is missing: re-running the script with an
# existing directory would otherwise make `git clone` fail on every start.
# The argument-list form avoids going through a shell, and check=True
# surfaces a failed clone here instead of as an import error later on.
if not os.path.isdir(_SOAT_DIR):
    subprocess.run(["git", "clone", _SOAT_REPO_URL, _SOAT_DIR], check=True)

# Make modules inside the checkout importable (presumably the `model` and
# `util` modules imported further down -- confirm against the repo layout).
if _SOAT_DIR not in sys.path:
    sys.path.append(_SOAT_DIR)
|
import os |
|
import torch |
|
import torchvision |
|
from torch import nn |
|
import numpy as np |
|
import torch.backends.cudnn as cudnn |
|
cudnn.benchmark = True |
|
|
|
import math |
|
import matplotlib.pyplot as plt |
|
import torch.nn.functional as F |
|
from model import * |
|
from tqdm import tqdm as tqdm |
|
import pickle |
|
from copy import deepcopy |
|
import warnings |
|
warnings.filterwarnings("ignore", category=UserWarning) |
|
import kornia.filters as k |
|
from torchvision.utils import save_image |
|
from util import * |
|
import scipy |
|
|
|
import gradio as gr |
|
|
|
import PIL |
|
|
|
from torchvision import transforms |
|
|
|
# Device used for the generator and for the latent codes fed to it.
device = 'cpu'

# Truncation value forwarded to generator.get_latent() during inference.
truncation = 0.7

# Build the generator once at import time, switch it to eval mode and move
# it to the target device.
# NOTE(review): the positional arguments are presumably (image size,
# style dimension, number of mapping layers) -- confirm against model.Generator.
generator = Generator(256, 512, 8, channel_multiplier=2)
generator = generator.eval()
generator = generator.to(device)
|
|
|
def display_image(image, size=None, mode='nearest', unnorm=False, title=''):
    """Convert an image tensor in [-1, 1] to an HWC NumPy array in [0, 1].

    Args:
        image: Tensor laid out as (C, H, W) or (N, C, H, W). For a batched
            tensor only the first element is converted.
        size: Optional target side length. When given and different from the
            tensor's last dimension, the image is resized to (size, size).
        mode: Interpolation mode passed to ``F.interpolate``.
        unnorm: Unused; kept so existing callers keep working.
        title: Unused; kept so existing callers keep working.

    Returns:
        A NumPy array of shape (H, W, C) with values clamped and rescaled
        from [-1, 1] to [0, 1].
    """
    # Drop autograd history and move to host memory so .numpy() is valid.
    image = image.detach().cpu()

    if size is not None and image.size(-1) != size:
        # F.interpolate with a 2-D target size requires a batched
        # (N, C, H, W) input, so add a temporary batch axis to (C, H, W)
        # tensors; it is stripped again by the dim() == 4 branch below.
        if image.dim() == 3:
            image = image.unsqueeze(0)
        image = F.interpolate(image, size=(size, size), mode=mode)

    if image.dim() == 4:
        image = image[0]

    # Map [-1, 1] -> [0, 1] and reorder CHW -> HWC for image consumers.
    rescaled = (image.clamp(-1, 1) + 1) / 2
    return rescaled.permute(1, 2, 0).numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def inferece(num, seed):
    """Generate a landscape panorama by stitching pairwise merged images.

    ``num`` latent codes are sampled from a seeded RNG, smoothed across
    neighbouring samples, and mapped to latents with the module-level
    ``generator``. Each pair of consecutive latents is merged with
    ``generator.merge_extension`` and the merged tiles are pasted side by
    side into one wide image.

    Args:
        num: Number of source images to blend; converted with ``int`` and
            required to be at least 2 (one merged tile needs two sources).
        seed: Seed for ``np.random.RandomState``; converted with ``int``.

    Returns:
        The panorama as returned by ``display_image`` (an HWC NumPy array
        with values in [0, 1]).

    Raises:
        ValueError: If ``num`` is smaller than 2.
    """
    # A bare ``import scipy`` is not guaranteed to load the ndimage
    # submodule, so import it explicitly where it is used.
    from scipy import ndimage

    num_im = int(num)
    random_seed = int(seed)
    if num_im < 2:
        # Without at least one pair there is nothing to merge, and the
        # stitching code below would fail with an opaque IndexError.
        raise ValueError(
            f"Number of images must be at least 2, got {num_im}"
        )

    mean_latent = load_model(generator, 'landscape.pt')

    plt.rcParams['figure.dpi'] = 300

    # Overlap margin cut from each side of the inner tiles.
    pad = 512 // 4

    random_state = np.random.RandomState(random_seed)

    with torch.no_grad():
        z = random_state.randn(num_im, 512).astype(np.float32)
        # Smooth along the sample axis only (sigma 0 on the latent axis),
        # wrapping around so the first and last samples are neighbours too.
        z = ndimage.gaussian_filter(z, [.7, 0], mode='wrap')
        # Renormalise to unit RMS after smoothing.
        z /= np.sqrt(np.mean(np.square(z)))
        z = torch.from_numpy(z).to(device)

        source = generator.get_latent(z, truncation=truncation, mean_latent=mean_latent)

        # One merged tile per pair of consecutive latents.
        all_im = [
            generator.merge_extension(
                index_layers(source, i), index_layers(source, i + 1)
            )
            for i in range(num_im - 1)
        ]

        # NOTE(review): the slicing below assumes every merged tile is
        # 512 px wide -- confirm against generator.merge_extension.
        b, c, h, _ = all_im[0].shape
        panorama_im = torch.zeros(b, c, h, 512 + (num_im - 2) * 256)

        # Left edge: the first tile contributes everything up to its
        # right-hand margin.
        coord = 256 + pad
        panorama_im[..., :coord] = all_im[0][..., :coord]

        # Inner tiles contribute only their central strip.
        strip = 512 - 2 * pad
        for im in all_im[1:]:
            panorama_im[..., coord:coord + strip] = im[..., pad:-pad]
            coord += strip

        # Right edge: close with the right-hand margin of the last tile.
        panorama_im[..., coord:] = all_im[-1][..., 512 - pad:]

    return display_image(panorama_im)
|
|
|
# Texts shown on the demo page: heading, introduction, and footer links.
title = "SOAT"

description = "Gradio demo for SOAT Panorama Generation for landscapes. Generate a panorama using a pretrained StyleGAN by stitching intermediate activations. To use it, simply add the number of images and random seed number. Read more at the links below."

article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.01619' target='_blank'>StyleGAN of All Trades: Image Manipulation with Only Pretrained StyleGAN</a> | <a href='https://github.com/mchong6/SOAT' target='_blank'>Github Repo</a></p>"
|
|
|
# Input widgets, listed in the positional order of inferece(num, seed).
num_images_input = gr.inputs.Number(default=5, label="Number of Images")
random_seed_input = gr.inputs.Number(default=90, label="Random Seed")

# inferece returns a NumPy image, so the output widget is configured for it.
panorama_output = gr.outputs.Image(type="numpy", label="Output")

demo = gr.Interface(
    inferece,
    [num_images_input, random_seed_input],
    panorama_output,
    title=title,
    description=description,
    article=article,
    theme="huggingface",
    enable_queue=True,
)

# Start the web UI.
demo.launch(debug=True)