# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
"""Wrap the generator to render a sequence of images""" | |
import torch | |
import torch.nn.functional as F | |
import numpy as np | |
from torch import random | |
import tqdm | |
import copy | |
import trimesh | |
class Renderer(object):
    """Wraps a generator (and optional discriminator) to render image sequences."""

    def __init__(self, generator, discriminator=None, program=None):
        self.generator = generator
        self.discriminator = discriminator
        self.sample_tmp = 0.65  # default temperature for latent-code sampling
        self.seed = 0
        # A program of the form "name:path" pairs a render routine with an
        # image folder used as encoder input; a bare name has no image data.
        if program is not None and len(program.split(':')) == 2:
            from training.dataset import ImageFolderDataset
            name, path = program.split(':')
            self.image_data = ImageFolderDataset(path)
            self.program = name
        else:
            self.image_data = None
            self.program = program
def set_random_seed(self, seed): | |
self.seed = seed | |
torch.manual_seed(seed) | |
np.random.seed(seed) | |
def __call__(self, *args, **kwargs): | |
self.generator.eval() # eval mode... | |
if self.program is None: | |
if hasattr(self.generator, 'get_final_output'): | |
return self.generator.get_final_output(*args, **kwargs) | |
return self.generator(*args, **kwargs) | |
if self.image_data is not None: | |
batch_size = 1 | |
indices = (np.random.rand(batch_size) * len(self.image_data)).tolist() | |
rimages = np.stack([self.image_data._load_raw_image(int(i)) for i in indices], 0) | |
rimages = torch.from_numpy(rimages).float().to(kwargs['z'].device) / 127.5 - 1 | |
kwargs['img'] = rimages | |
outputs = getattr(self, f"render_{self.program}")(*args, **kwargs) | |
if self.image_data is not None: | |
imgs = outputs if not isinstance(outputs, tuple) else outputs[0] | |
size = imgs[0].size(-1) | |
rimg = F.interpolate(rimages, (size, size), mode='bicubic', align_corners=False) | |
imgs = [torch.cat([img, rimg], 0) for img in imgs] | |
outputs = imgs if not isinstance(outputs, tuple) else (imgs, outputs[1]) | |
return outputs | |
def get_additional_params(self, ws, t=0): | |
gen = self.generator.synthesis | |
batch_size = ws.size(0) | |
kwargs = {} | |
if not hasattr(gen, 'get_latent_codes'): | |
return kwargs | |
s_val, t_val, r_val = [[0, 0, 0]], [[0.5, 0.5, 0.5]], [0.] | |
# kwargs["transformations"] = gen.get_transformations(batch_size=batch_size, mode=[s_val, t_val, r_val], device=ws.device) | |
# kwargs["bg_rotation"] = gen.get_bg_rotation(batch_size, device=ws.device) | |
# kwargs["light_dir"] = gen.get_light_dir(batch_size, device=ws.device) | |
kwargs["latent_codes"] = gen.get_latent_codes(batch_size, tmp=self.sample_tmp, device=ws.device) | |
kwargs["camera_matrices"] = self.get_camera_traj(t, ws.size(0), device=ws.device) | |
return kwargs | |
def get_camera_traj(self, t, batch_size=1, traj_type='pigan', device='cpu'): | |
gen = self.generator.synthesis | |
if traj_type == 'pigan': | |
range_u, range_v = gen.C.range_u, gen.C.range_v | |
pitch = 0.2 * np.cos(t * 2 * np.pi) + np.pi/2 | |
yaw = 0.4 * np.sin(t * 2 * np.pi) | |
u = (yaw - range_u[0]) / (range_u[1] - range_u[0]) | |
v = (pitch - range_v[0]) / (range_v[1] - range_v[0]) | |
cam = gen.get_camera(batch_size=batch_size, mode=[u, v, 0.5], device=device) | |
else: | |
raise NotImplementedError | |
return cam | |
def render_rotation_camera(self, *args, **kwargs): | |
batch_size, n_steps = 2, kwargs["n_steps"] | |
gen = self.generator.synthesis | |
if 'img' not in kwargs: | |
ws = self.generator.mapping(*args, **kwargs) | |
else: | |
ws, _ = self.generator.encoder(kwargs['img']) | |
# ws = ws.repeat(batch_size, 1, 1) | |
# kwargs["not_render_background"] = True | |
if hasattr(gen, 'get_latent_codes'): | |
kwargs["latent_codes"] = gen.get_latent_codes(batch_size, tmp=self.sample_tmp, device=ws.device) | |
kwargs.pop('img', None) | |
out = [] | |
cameras = [] | |
relatve_range_u = kwargs['relative_range_u'] | |
u_samples = np.linspace(relatve_range_u[0], relatve_range_u[1], n_steps) | |
for step in tqdm.tqdm(range(n_steps)): | |
# Set Camera | |
u = u_samples[step] | |
kwargs["camera_matrices"] = gen.get_camera(batch_size=batch_size, mode=[u, 0.5, 0.5], device=ws.device) | |
cameras.append(gen.get_camera(batch_size=batch_size, mode=[u, 0.5, 0.5], device=ws.device)) | |
with torch.no_grad(): | |
out_i = gen(ws, **kwargs) | |
if isinstance(out_i, dict): | |
out_i = out_i['img'] | |
out.append(out_i) | |
if 'return_cameras' in kwargs and kwargs["return_cameras"]: | |
return out, cameras | |
else: | |
return out | |
def render_rotation_camera3(self, styles=None, *args, **kwargs): | |
gen = self.generator.synthesis | |
n_steps = 36 # 120 | |
if styles is None: | |
batch_size = 2 | |
if 'img' not in kwargs: | |
ws = self.generator.mapping(*args, **kwargs) | |
else: | |
ws = self.generator.encoder(kwargs['img'])['ws'] | |
# ws = ws.repeat(batch_size, 1, 1) | |
else: | |
ws = styles | |
batch_size = ws.size(0) | |
# kwargs["not_render_background"] = True | |
# Get Random codes and bg rotation | |
self.sample_tmp = 0.72 | |
if hasattr(gen, 'get_latent_codes'): | |
kwargs["latent_codes"] = gen.get_latent_codes(batch_size, tmp=self.sample_tmp, device=ws.device) | |
kwargs.pop('img', None) | |
# if getattr(gen, "use_noise", False): | |
# from dnnlib.geometry import extract_geometry | |
# kwargs['meshes'] = {} | |
# low_res, high_res = gen.resolution_vol, gen.img_resolution | |
# res = low_res * 2 | |
# while res <= high_res: | |
# kwargs['meshes'][res] = [trimesh.Trimesh(*extract_geometry(gen, ws, resolution=res, threshold=30.))] | |
# kwargs['meshes'][res] += [ | |
# torch.randn(len(kwargs['meshes'][res][0].vertices), | |
# 2, device=ws.device)[kwargs['meshes'][res][0].faces]] | |
# res = res * 2 | |
# if getattr(gen, "use_noise", False): | |
# kwargs['voxel_noise'] = gen.get_voxel_field(styles=ws, n_vols=2048, return_noise=True, sphere_noise=True) | |
# if getattr(gen, "use_voxel_noise", False): | |
# kwargs['voxel_noise'] = gen.get_voxel_field(styles=ws, n_vols=128, return_noise=True) | |
kwargs['noise_mode'] = 'const' | |
out = [] | |
tspace = np.linspace(0, 1, n_steps) | |
range_u, range_v = gen.C.range_u, gen.C.range_v | |
for step in tqdm.tqdm(range(n_steps)): | |
t = tspace[step] | |
pitch = 0.2 * np.cos(t * 2 * np.pi) + np.pi/2 | |
yaw = 0.4 * np.sin(t * 2 * np.pi) | |
u = (yaw - range_u[0]) / (range_u[1] - range_u[0]) | |
v = (pitch - range_v[0]) / (range_v[1] - range_v[0]) | |
kwargs["camera_matrices"] = gen.get_camera( | |
batch_size=batch_size, mode=[u, v, t], device=ws.device) | |
with torch.no_grad(): | |
out_i = gen(ws, **kwargs) | |
if isinstance(out_i, dict): | |
out_i = out_i['img'] | |
out.append(out_i) | |
return out | |
def render_rotation_both(self, *args, **kwargs): | |
gen = self.generator.synthesis | |
batch_size, n_steps = 1, 36 | |
if 'img' not in kwargs: | |
ws = self.generator.mapping(*args, **kwargs) | |
else: | |
ws, _ = self.generator.encoder(kwargs['img']) | |
ws = ws.repeat(batch_size, 1, 1) | |
# kwargs["not_render_background"] = True | |
# Get Random codes and bg rotation | |
kwargs["latent_codes"] = gen.get_latent_codes(batch_size, tmp=self.sample_tmp, device=ws.device) | |
kwargs.pop('img', None) | |
out = [] | |
tspace = np.linspace(0, 1, n_steps) | |
range_u, range_v = gen.C.range_u, gen.C.range_v | |
for step in tqdm.tqdm(range(n_steps)): | |
t = tspace[step] | |
pitch = 0.2 * np.cos(t * 2 * np.pi) + np.pi/2 | |
yaw = 0.4 * np.sin(t * 2 * np.pi) | |
u = (yaw - range_u[0]) / (range_u[1] - range_u[0]) | |
v = (pitch - range_v[0]) / (range_v[1] - range_v[0]) | |
kwargs["camera_matrices"] = gen.get_camera( | |
batch_size=batch_size, mode=[u, v, 0.5], device=ws.device) | |
with torch.no_grad(): | |
out_i = gen(ws, **kwargs) | |
if isinstance(out_i, dict): | |
out_i = out_i['img'] | |
kwargs_n = copy.deepcopy(kwargs) | |
kwargs_n.update({'render_option': 'early,no_background,up64,depth,normal'}) | |
out_n = gen(ws, **kwargs_n) | |
out_n = F.interpolate(out_n, | |
size=(out_i.size(-1), out_i.size(-1)), | |
mode='bicubic', align_corners=True) | |
out_i = torch.cat([out_i, out_n], 0) | |
out.append(out_i) | |
return out | |
def render_rotation_grid(self, styles=None, return_cameras=False, *args, **kwargs): | |
gen = self.generator.synthesis | |
if styles is None: | |
batch_size = 1 | |
ws = self.generator.mapping(*args, **kwargs) | |
ws = ws.repeat(batch_size, 1, 1) | |
else: | |
ws = styles | |
batch_size = ws.size(0) | |
kwargs["latent_codes"] = gen.get_latent_codes(batch_size, tmp=self.sample_tmp, device=ws.device) | |
kwargs.pop('img', None) | |
if getattr(gen, "use_voxel_noise", False): | |
kwargs['voxel_noise'] = gen.get_voxel_field(styles=ws, n_vols=128, return_noise=True) | |
out = [] | |
cameras = [] | |
range_u, range_v = gen.C.range_u, gen.C.range_v | |
a_steps, b_steps = 6, 3 | |
aspace = np.linspace(-0.4, 0.4, a_steps) | |
bspace = np.linspace(-0.2, 0.2, b_steps) * -1 | |
for b in tqdm.tqdm(range(b_steps)): | |
for a in range(a_steps): | |
t_a = aspace[a] | |
t_b = bspace[b] | |
camera_mat = gen.camera_matrix.repeat(batch_size, 1, 1).to(ws.device) | |
loc_x = np.cos(t_b) * np.cos(t_a) | |
loc_y = np.cos(t_b) * np.sin(t_a) | |
loc_z = np.sin(t_b) | |
loc = torch.tensor([[loc_x, loc_y, loc_z]], dtype=torch.float32).to(ws.device) | |
from dnnlib.camera import look_at | |
R = look_at(loc) | |
RT = torch.eye(4).reshape(1, 4, 4).repeat(batch_size, 1, 1) | |
RT[:, :3, :3] = R | |
RT[:, :3, -1] = loc | |
world_mat = RT.to(ws.device) | |
#kwargs["camera_matrices"] = gen.get_camera( | |
# batch_size=batch_size, mode=[u, v, 0.5], device=ws.device) | |
kwargs["camera_matrices"] = (camera_mat, world_mat, "random", None) | |
with torch.no_grad(): | |
out_i = gen(ws, **kwargs) | |
if isinstance(out_i, dict): | |
out_i = out_i['img'] | |
# kwargs_n = copy.deepcopy(kwargs) | |
# kwargs_n.update({'render_option': 'early,no_background,up64,depth,normal'}) | |
# out_n = gen(ws, **kwargs_n) | |
# out_n = F.interpolate(out_n, | |
# size=(out_i.size(-1), out_i.size(-1)), | |
# mode='bicubic', align_corners=True) | |
# out_i = torch.cat([out_i, out_n], 0) | |
out.append(out_i) | |
if return_cameras: | |
return out, cameras | |
else: | |
return out | |
def render_rotation_camera_grid(self, *args, **kwargs): | |
batch_size, n_steps = 1, 60 | |
gen = self.generator.synthesis | |
bbox_generator = self.generator.synthesis.boundingbox_generator | |
ws = self.generator.mapping(*args, **kwargs) | |
ws = ws.repeat(batch_size, 1, 1) | |
# Get Random codes and bg rotation | |
kwargs["latent_codes"] = gen.get_latent_codes(batch_size, tmp=self.sample_tmp, device=ws.device) | |
del kwargs['render_option'] | |
out = [] | |
for v in [0.15, 0.5, 1.05]: | |
for step in tqdm.tqdm(range(n_steps)): | |
# Set Camera | |
u = step * 1.0 / (n_steps - 1) - 1.0 | |
kwargs["camera_matrices"] = gen.get_camera(batch_size=batch_size, mode=[u, v, 0.5], device=ws.device) | |
with torch.no_grad(): | |
out_i = gen(ws, render_option=None, **kwargs) | |
if isinstance(out_i, dict): | |
out_i = out_i['img'] | |
# option_n = 'early,no_background,up64,depth,direct_depth' | |
# option_n = 'early,up128,no_background,depth,normal' | |
# out_n = gen(ws, render_option=option_n, **kwargs) | |
# out_n = F.interpolate(out_n, | |
# size=(out_i.size(-1), out_i.size(-1)), | |
# mode='bicubic', align_corners=True) | |
# out_i = torch.cat([out_i, out_n], 0) | |
out.append(out_i) | |
# out += out[::-1] | |
return out |