File size: 3,624 Bytes
b2fd97c
e2cf2b0
b2fd97c
 
93eccf3
 
b2fd97c
93eccf3
 
4245274
93eccf3
e2cf2b0
 
 
 
 
 
 
c1a6745
e2cf2b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1a6745
e2cf2b0
9a66c24
e2cf2b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import pickle
import sys

import subprocess
import tarfile

import subprocess


    
import imageio
import numpy as np
import scipy.interpolate
import torch
from tqdm import tqdm
import gradio as gr 
from huggingface_hub import hf_hub_download


def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True):
    """Tile a batch of images into a single grid image.

    Args:
        img: 4-D tensor ``(batch, channels, height, width)``.
        grid_w: Number of grid columns; defaults to ``batch // grid_h``.
        grid_h: Number of grid rows.
        float_to_uint8: If true, map values with ``x * 127.5 + 128``, clamp
            to ``[0, 255]`` and cast to ``torch.uint8``.
        chw_to_hwc: If true, return channels-last ``(H, W, C)`` instead of
            ``(C, H, W)``.
        to_numpy: If true, move the result to the CPU and return a NumPy
            array; otherwise return the tensor.

    Returns:
        The grid, shaped ``(channels, grid_h * height, grid_w * width)`` or
        its channels-last transpose. Image ``i`` lands in row ``i // grid_w``,
        column ``i % grid_w``.

    Raises:
        AssertionError: If ``batch != grid_w * grid_h``.
    """
    n_img, n_ch, height, width = img.shape
    if grid_w is None:
        grid_w = n_img // grid_h
    assert n_img == grid_w * grid_h

    out = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8) if float_to_uint8 else img

    # (rows, cols, C, H, W) -> (C, rows, H, cols, W) -> (C, rows*H, cols*W)
    out = (
        out.reshape(grid_h, grid_w, n_ch, height, width)
        .permute(2, 0, 3, 1, 4)
        .reshape(n_ch, grid_h * height, grid_w * width)
    )
    if chw_to_hwc:
        out = out.permute(1, 2, 0)
    return out.cpu().numpy() if to_numpy else out




# Generator checkpoint; the relative path resolves against the current
# working directory.
network_pkl = 'braingan-400.pkl'

# Loaded once, at import time. NOTE: unpickling can execute arbitrary code
# stored in the file, so only load checkpoints from a trusted source.
with open(network_pkl, mode='rb') as f:
    # Keep only the 'G_ema' entry of the unpickled mapping; indexing
    # immediately lets the rest of the object be freed.
    G = pickle.load(f)['G_ema']

def predict(Seed, choices):
    """Render a looping latent-interpolation video with the module-level generator ``G``.

    Every cell of a ``grid_w`` x ``grid_h`` grid cycles through its own
    keyframe latents (one per seed), interpolated in W space with
    ``scipy.interpolate.interp1d``. Each output frame is the grid of
    synthesized images tiled by ``layout_grid`` and appended to the MP4.

    Args:
        Seed: Exclusive upper bound of the seed range; the ``n_seeds``
            integers just below it are used (16 for ``'4x2'``, 4 for
            ``'2x1'``). The Gradio slider may deliver a float, so it is
            truncated with ``int()``.
        choices: Grid layout as ``'<columns>x<rows>'``: ``'4x2'`` or ``'2x1'``.

    Returns:
        str: Path of the written video, ``'ex.mp4'`` (relative to the current
        working directory).

    Raises:
        ValueError: If ``choices`` is not a supported layout, or ``Seed`` is
            smaller than the number of seeds the layout needs (the seeds would
            be negative, which ``np.random.RandomState`` rejects).
    """
    # Supported layouts: name -> (grid_w, grid_h, number of seeds).
    # Both use 2 seeds per grid cell, i.e. 2 keyframes per cell.
    layouts = {'4x2': (4, 2, 16), '2x1': (2, 1, 4)}
    if choices not in layouts:
        raise ValueError(f'Unknown grid layout {choices!r}; expected one of {list(layouts)}')
    grid_w, grid_h, n_seeds = layouts[choices]

    seed_end = int(Seed)
    if seed_end < n_seeds:
        raise ValueError(f'Seed must be at least {n_seeds} for the {choices} grid')
    seeds = list(range(seed_end - n_seeds, seed_end))

    mp4 = 'ex.mp4'
    w_frames = 60 * 4  # frames rendered per keyframe interval
    kind = 'cubic'     # interpolation kind passed to interp1d
    wraps = 2          # extra copies of the keyframe sequence on each side
    psi = 1            # truncation_psi passed to G.mapping
    # Fall back to CPU so the app still runs (slowly) on a machine without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    G.eval()
    G.to(device)

    # Seeds are consumed cell-major: after the reshape below, cell (yi, xi)
    # owns num_keyframes consecutive seeds.
    num_keyframes = len(seeds) // (grid_w * grid_h)

    # Inference only: no autograd graph is needed for mapping or synthesis.
    with torch.no_grad():
        zs = torch.from_numpy(
            np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in seeds])
        ).to(device)
        ws = G.mapping(z=zs, c=None, truncation_psi=psi)
        _ = G.synthesis(ws[:1])  # warm up
        ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:])

        # One interpolator per cell. The keyframes are tiled (2 * wraps + 1)
        # times along x, so the curve near either end of [0, num_keyframes)
        # is shaped by the keyframes that wrap around to it.
        x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1))
        grid = [
            [
                scipy.interpolate.interp1d(
                    x, np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1]), kind=kind, axis=0
                )
                for xi in range(grid_w)
            ]
            for yi in range(grid_h)
        ]

        # The context manager closes (and finalizes) the file even if synthesis raises.
        with imageio.get_writer(mp4, mode='I', fps=60, codec='libx264') as video_out:
            for frame_idx in tqdm(range(num_keyframes * w_frames)):
                t = frame_idx / w_frames
                # Row-major over the grid, matching layout_grid's ordering.
                imgs = [
                    G.synthesis(ws=torch.from_numpy(interp(t)).to(device).unsqueeze(0), noise_mode='const')[0]
                    for row in grid
                    for interp in row
                ]
                video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h))
    return mp4



# Grid layouts offered in the UI; each must be a layout predict() understands.
choices = ['4x2', '2x1']

# UI components, built in the same order they were previously built inline.
seed_input = gr.inputs.Slider(minimum=16, maximum=2**10, label='Seed')
grid_input = gr.inputs.Radio(choices=choices, default='4x2', label='Image Grid')
video_output = gr.outputs.Video(label='Video')

interface = gr.Interface(
    fn=predict,
    title="Brain MR Image Generation with StyleGAN-2",
    description="",
    article="Author: S.Serdar Helli",
    inputs=[seed_input, grid_input],
    outputs=video_output,
)

interface.launch(debug=True)