# Spaces: Runtime error
# (Apparently a status banner captured from the hosting page; not part of the program.)
import os | |
import pickle | |
import sys | |
sys.path.insert(0, 'stylegan_xl') | |
import imageio | |
import numpy as np | |
import scipy.interpolate | |
import torch | |
from tqdm import tqdm | |
import gradio as gr | |
from huggingface_hub import hf_hub_download | |
def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True):
    """Tile a batch of images into a single grid image.

    Images are placed in row-major order: batch item ``i`` lands in grid row
    ``i // grid_w`` and grid column ``i % grid_w``.

    Args:
        img: 4-D tensor indexed as (batch, channels, height, width).
        grid_w: Number of grid columns. When ``None`` it is derived as
            ``batch // grid_h``.
        grid_h: Number of grid rows.
        float_to_uint8: When true, values are mapped with ``x * 127.5 + 128``,
            clamped to ``[0, 255]`` and cast to ``torch.uint8``.
        chw_to_hwc: When true, the result is permuted from
            (channels, height, width) to (height, width, channels).
        to_numpy: When true, the result is moved to the CPU and returned as a
            NumPy array; otherwise a tensor is returned.

    Returns:
        The tiled image of spatial size
        ``(grid_h * height, grid_w * width)``.

    Raises:
        ValueError: If the batch size does not equal ``grid_w * grid_h``.
    """
    batch_size, channels, img_h, img_w = img.shape
    if grid_w is None:
        grid_w = batch_size // grid_h
    # Explicit check instead of `assert`, which disappears under `python -O`.
    if batch_size != grid_w * grid_h:
        raise ValueError(
            'Batch size {} does not fill a {}x{} grid'.format(batch_size, grid_w, grid_h)
        )

    if float_to_uint8:
        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)

    # (batch, C, H, W) -> (grid_h, grid_w, C, H, W) -> (C, grid_h, H, grid_w, W)
    # so that merging adjacent axes stitches rows and columns together.
    img = img.reshape(grid_h, grid_w, channels, img_h, img_w)
    img = img.permute(2, 0, 3, 1, 4)
    img = img.reshape(channels, grid_h * img_h, grid_w * img_w)

    if chw_to_hwc:
        img = img.permute(1, 2, 0)
    if to_numpy:
        img = img.cpu().numpy()
    return img
# Relative path of the pickled network snapshot, resolved against the working directory.
network_pkl = 'braingan-400.pkl'

# NOTE: unpickling executes code stored in the file, so only load snapshots
# from a trusted source. The 'G_ema' entry is the generator used by predict().
with open(network_pkl, 'rb') as f:
    G = pickle.load(f)['G_ema']
def predict(Seed, choices):
    """Render a latent-interpolation video for a grid of generated images.

    Each grid cell interpolates (cubic, looping) between two keyframes whose
    latents are drawn from consecutive integer seeds ending just below
    ``Seed``. The video is written to ``ex.mp4`` in the working directory:
    240 frames per keyframe at 60 fps.

    Args:
        Seed: Exclusive upper bound of the seed range. It is truncated to an
            integer; the '4x2' grid uses seeds ``Seed-16 .. Seed-1`` and the
            '2x1' grid uses ``Seed-4 .. Seed-1``.
        choices: Grid layout, either '4x2' or '2x1' (columns x rows).

    Returns:
        The path of the written video file, ``'ex.mp4'``.

    Raises:
        ValueError: If ``choices`` is not a supported layout, or if ``Seed``
            is too small for the seed range to stay non-negative.
    """
    # layout name -> (grid columns, grid rows, number of consecutive seeds)
    layouts = {'4x2': (4, 2, 16), '2x1': (2, 1, 4)}
    if choices not in layouts:
        raise ValueError(
            'Unsupported image grid {!r}; expected one of {}'.format(choices, sorted(layouts))
        )
    grid_w, grid_h, num_seeds = layouts[choices]

    # UI sliders may deliver floats; seeds must be non-negative integers.
    last_seed = int(Seed)
    first_seed = last_seed - num_seeds
    if first_seed < 0:
        raise ValueError(
            'Seed must be at least {} for the {} grid'.format(num_seeds, choices)
        )
    seeds = np.arange(first_seed, last_seed, dtype=np.int64)

    cells = grid_w * grid_h
    if len(seeds) % cells != 0:
        raise ValueError('Number of input seeds must be divisible by grid W*H')
    num_keyframes = len(seeds) // cells

    w_frames = 60 * 4  # frames rendered per keyframe interval
    kind = 'cubic'     # interpolation kind passed to scipy's interp1d
    wraps = 2          # keyframe sequence is repeated this many times on each side
    psi = 1            # truncation psi passed to the mapping network
    mp4 = 'ex.mp4'

    # Fall back to the CPU so the function does not crash on hosts without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    G.eval()
    G.to(device)

    # Inference only: no autograd graph is needed for any of the calls below.
    with torch.no_grad():
        zs = torch.from_numpy(
            np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in seeds])
        ).to(device)
        ws = G.mapping(z=zs, c=None, truncation_psi=psi)
        _ = G.synthesis(ws[:1])  # warm up
        ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:])

        # Interpolation: one interpolator per grid cell. The keyframes are
        # tiled (2 * wraps + 1) times so the curve is periodic around the
        # range that is actually sampled.
        x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1))
        grid = []
        for yi in range(grid_h):
            row = []
            for xi in range(grid_w):
                y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1])
                row.append(scipy.interpolate.interp1d(x, y, kind=kind, axis=0))
            grid.append(row)

        # Render video. The writer is closed even if a frame fails, so the
        # file handle and encoder process are not leaked.
        video_out = imageio.get_writer(mp4, mode='I', fps=60, codec='libx264')
        try:
            for frame_idx in tqdm(range(num_keyframes * w_frames)):
                imgs = []
                for row in grid:
                    for interp in row:
                        w = torch.from_numpy(interp(frame_idx / w_frames)).to(device)
                        img = G.synthesis(ws=w.unsqueeze(0), noise_mode='const')[0]
                        imgs.append(img)
                video_out.append_data(
                    layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h)
                )
        finally:
            video_out.close()
    return mp4
# Grid layouts offered in the UI; predict() accepts exactly these strings.
choices = ['4x2', '2x1']

# The slider starts at 16 because the '4x2' layout consumes the 16 seeds
# immediately below the selected value.
interface = gr.Interface(
    fn=predict,
    title="Brain MR Image Generation with StyleGAN-2",
    description="",
    article="Author: S.Serdar Helli",
    inputs=[
        gr.inputs.Slider(minimum=16, maximum=2**10, label='Seed'),
        gr.inputs.Radio(choices=choices, default='4x2', label='Image Grid'),
    ],
    outputs=gr.outputs.Video(label='Video'),
)
interface.launch(debug=True)