# Spaces:
# Runtime error
# Runtime error
import sys | |
import os | |
os.system("git clone https://github.com/dunbar12138/pix2pix3D.git") | |
sys.path.append("pix2pix3D") | |
from typing import List, Optional, Tuple, Union | |
import dnnlib | |
import numpy as np | |
import PIL.Image | |
import torch | |
from tqdm import tqdm | |
import legacy | |
from camera_utils import LookAtPoseSampler | |
from huggingface_hub import hf_hub_download | |
from matplotlib import pyplot as plt | |
from pathlib import Path | |
import gradio as gr | |
from training.utils import color_mask, color_list | |
import plotly.graph_objects as go | |
from tqdm import tqdm | |
import imageio | |
import trimesh | |
import mcubes | |
import copy | |
import pickle | |
import numpy as np | |
import torch | |
import dnnlib | |
from torch_utils import misc | |
from legacy import * | |
import io | |
os.environ["PYOPENGL_PLATFORM"] = "egl" | |
def get_sigma_field_np(nerf, styles, resolution=512, block_resolution=64):
    """Sample the generator's density (sigma) on a dense cubic grid.

    The cube spans ``[-bound, bound]^3`` with
    ``bound = nerf.rendering_kwargs['box_warp'] * 0.5``. It is evaluated in
    sub-cubes of at most ``block_resolution`` samples per axis, so each call to
    ``nerf.sample_mixed`` only sees a bounded number of points.

    Args:
        nerf: Generator exposing ``rendering_kwargs`` and
            ``sample_mixed(coords, directions, ws=..., noise_mode=...)`` whose
            result has a ``'sigma'`` entry with one value per query point.
        styles: Latent codes forwarded as ``ws``. Its ``.device`` decides where
            query points are placed.
        resolution: Number of samples per axis (cast to ``int``).
        block_resolution: Samples per axis per block (cast to ``int``).

    Returns:
        ``(sigma, bound)``. ``sigma`` is a float32 array of shape
        ``(resolution,)*3`` indexed ``[x, y, z]``; ``bound`` is half the box width.
    """
    # UI sliders may deliver floats; torch.linspace requires an integer step count.
    resolution = int(resolution)
    block_resolution = int(block_resolution)

    bound = nerf.rendering_kwargs['box_warp'] * 0.5
    axis_chunks = torch.linspace(-bound, bound, resolution).split(block_resolution)
    sigma_np = np.zeros([resolution, resolution, resolution], dtype=np.float32)

    # Pure inference: skip autograd bookkeeping. The caller may not be inside
    # a torch.no_grad() context.
    with torch.no_grad():
        for xi, xs in enumerate(axis_chunks):
            for yi, ys in enumerate(axis_chunks):
                for zi, zs in enumerate(axis_chunks):
                    # indexing='ij' makes axes 0/1/2 follow xs/ys/zs, matching sigma_np[x, y, z].
                    xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing='ij')
                    pts = torch.stack([xx, yy, zz], dim=-1).reshape(1, -1, 3).to(styles.device)
                    out = nerf.sample_mixed(pts, None, ws=styles, noise_mode='const')
                    # Every chunk except possibly the last has exactly block_resolution samples,
                    # so a chunk's start index is its position times block_resolution.
                    x0, y0, z0 = xi * block_resolution, yi * block_resolution, zi * block_resolution
                    sigma_np[x0:x0 + len(xs), y0:y0 + len(ys), z0:z0 + len(zs)] = (
                        out['sigma'].reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()
                    )
    return sigma_np, bound
def extract_geometry(nerf, styles, resolution, threshold):
    """Extract an iso-surface mesh from the generator's density field.

    Samples sigma on a ``resolution``^3 grid via ``get_sigma_field_np``, runs
    marching cubes at ``threshold``, and maps the resulting grid-index vertices
    back into the ``[-bound, bound]^3`` sampling cube.

    Returns:
        ``(vertices, faces)``: float32 vertex coordinates and the face array
        produced by ``mcubes.marching_cubes``.
    """
    sigma_grid, bound = get_sigma_field_np(nerf, styles, resolution)
    grid_vertices, faces = mcubes.marching_cubes(sigma_grid, threshold)

    # Linear map from index space [0, resolution - 1] onto [-bound, bound] per axis.
    lower = np.array([-bound] * 3)
    upper = np.array([bound] * 3)
    extent = upper - lower
    world_vertices = grid_vertices / (resolution - 1.0) * extent[None, :] + lower[None, :]
    return world_vertices.astype('float32'), faces
def render_video(G, ws, intrinsics, num_frames = 120, pitch_range = 0.25, yaw_range = 0.35, neural_rendering_resolution = 128, device='cuda'):
    """Render RGB and semantic-label frames along an orbiting camera path.

    The camera yaw/pitch oscillate sinusoidally around 3.14/2 (pitch offset by
    -0.05) over ``num_frames`` steps. The camera looks at the generator's average
    pivot from its average radius.

    Returns:
        ``(frames, frames_label)``: two lists with one uint8 array per frame.
        RGB frames are HxWx3; label frames are whatever ``color_mask`` returns,
        cast to uint8.
    """
    rgb_frames = []
    label_frames = []
    for step in tqdm(range(num_frames)):
        phase = 2 * 3.14 * step / num_frames
        yaw = 3.14/2 + yaw_range * np.sin(phase)
        pitch = 3.14/2 - 0.05 + pitch_range * np.cos(phase)
        cam2world = LookAtPoseSampler.sample(
            yaw,
            pitch,
            torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device),
            radius=G.rendering_kwargs['avg_camera_radius'],
            device=device,
        )
        # Camera conditioning: flattened 4x4 extrinsics followed by flattened 3x3 intrinsics.
        camera = torch.cat([cam2world.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
        with torch.no_grad():
            out = G.synthesis(ws, camera, noise_mode='const', neural_rendering_resolution=neural_rendering_resolution)
            # Map generator output from [-1, 1] to [0, 255] and move channels last.
            rgb = ((out['image'].cpu().numpy()[0] + 1) * 127.5).clip(0, 255).astype(np.uint8)
            rgb_frames.append(rgb.transpose(1, 2, 0))
            class_ids = torch.argmax(out['semantic'], dim=1).cpu().numpy()[0]
            label_frames.append(color_mask(class_ids).astype(np.uint8))
    return rgb_frames, label_frames
def return_plot_go(mesh_trimesh):
    """Build a Plotly ``Mesh3d`` figure from a vertex-coloured trimesh.

    Vertex positions, triangle indices and ``visual.vertex_colors`` are passed
    straight through. Lighting is fixed to a shiny, strongly lit look.
    """
    vertex_cols = np.asarray(mesh_trimesh.vertices).T
    face_cols = np.asarray(mesh_trimesh.faces).T

    lighting = dict(
        ambient=0.5,
        diffuse=1,
        fresnel=4,
        specular=0.5,
        roughness=0.05,
        facenormalsepsilon=0,
        vertexnormalsepsilon=0,
    )
    light_position = dict(x=100, y=100, z=1000)

    surface = go.Mesh3d(
        x=vertex_cols[0], y=vertex_cols[1], z=vertex_cols[2],
        i=face_cols[0], j=face_cols[1], k=face_cols[2],
        vertexcolor=np.asarray(mesh_trimesh.visual.vertex_colors),
        lighting=lighting,
        lightposition=light_position,
    )
    return go.Figure(surface)
# Fetch the pretrained seg2cat checkpoint from the Hugging Face Hub; the returned
# path is what get_all later opens with dnnlib.util.open_url.
network_cat = hf_hub_download(
    "SerdarHelli/pix2pix3d_seg2cat",
    filename="pix2pix3d_seg2cat.pkl",
    revision="main",
)

# Checkpoints keyed by the config name selected in the UI dropdown.
models = {"seg2cat": network_cat}

# Use the GPU when one is visible to torch, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Directory that generated videos are written into.
outdir = "./"
class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that rebuilds pickled torch storages on the CPU.

    Pickled storages are restored through ``torch.storage._load_from_bytes``.
    This subclass swaps that hook for ``torch.load(..., map_location='cpu')``,
    so a checkpoint can be opened on a host without a GPU. Every other global
    is resolved normally.

    SECURITY: unpickling can execute arbitrary code; only use on trusted files.
    """

    def find_class(self, module, name):
        if (module, name) != ('torch.storage', '_load_from_bytes'):
            return super().find_class(module, name)

        def _load_storage_on_cpu(payload):
            return torch.load(io.BytesIO(payload), map_location='cpu')

        return _load_storage_on_cpu
def load_network_pkl_cpu(f, force_fp16=False):
    """Load a network pickle with all torch storages mapped to the CPU.

    Used by ``get_all`` instead of ``legacy.load_network_pkl`` when running on
    CPU. The pickle is read through ``CPU_Unpickler``.

    Args:
        f: Binary file object positioned at the start of the pickle.
        force_fp16: If True, rebuild ``G``, ``D`` and ``G_ema`` with
            ``num_fp16_res = 4`` and ``conv_clamp = 256`` and copy the weights over.

    Returns:
        dict with keys ``'G'``, ``'D'``, ``'G_ema'``, ``'training_set_kwargs'``
        and ``'augment_pipe'``.

    Raises:
        AssertionError: if the loaded contents do not have the expected types.

    SECURITY: unpickling can execute arbitrary code; only load trusted checkpoints.
    """
    data = CPU_Unpickler(f).load()

    # A 3-tuple of TF network stubs marks a legacy TensorFlow pickle: convert it.
    is_tf_pickle = (
        isinstance(data, tuple)
        and len(data) == 3
        and all(isinstance(net, _TFNetworkStub) for net in data)
    )
    if is_tf_pickle:
        tf_G, tf_D, tf_Gs = data
        G = convert_tf_generator(tf_G)
        D = convert_tf_discriminator(tf_D)
        G_ema = convert_tf_generator(tf_Gs)
        data = dict(G=G, D=D, G_ema=G_ema)

    # Fill in optional entries that older pickles may lack.
    for optional_key in ('training_set_kwargs', 'augment_pipe'):
        if optional_key not in data:
            data[optional_key] = None

    # Validate contents.
    for net_key in ('G', 'D', 'G_ema'):
        assert isinstance(data[net_key], torch.nn.Module)
    assert isinstance(data['training_set_kwargs'], (dict, type(None)))
    assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))

    if force_fp16:
        for net_key in ('G', 'D', 'G_ema'):
            old_net = data[net_key]
            new_kwargs = copy.deepcopy(old_net.init_kwargs)
            # Patch the nested synthesis_kwargs when present, else the top-level kwargs.
            # NOTE(review): attribute assignment assumes an EasyDict-like kwargs object.
            target = new_kwargs.get('synthesis_kwargs', new_kwargs)
            target.num_fp16_res = 4
            target.conv_clamp = 256
            # Only rebuild the network when the patch actually changed its settings.
            if new_kwargs != old_net.init_kwargs:
                new_net = type(old_net)(**new_kwargs).eval().requires_grad_(False)
                misc.copy_params_and_buffers(old_net, new_net, require_all=True)
                data[net_key] = new_net
    return data
# RGB palette for colour-coded masks: entry i is the colour of label i.
# colormap2labelmap decodes masks with it, and get_all indexes it with argmax
# class ids to colour mesh vertices.
# NOTE(review): this rebinds the color_list imported from training.utils.
color_list = [
    [255, 255, 255], [204, 0, 0], [76, 153, 0], [204, 204, 0],
    [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204],
    [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0],
    [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204],
    [0, 51, 0], [255, 153, 51], [0, 204, 0],
]
def colormap2labelmap(color_img, palette=None):
    """Convert a colour-coded mask into an integer label map.

    Each pixel whose RGB value equals ``palette[i]`` exactly gets label ``i``.
    Pixels matching no palette entry get label 0. Any channels beyond the first
    three (e.g. the alpha channel of an RGBA PNG) are ignored.

    Args:
        color_img: Array of shape (H, W, C) with C >= 3.
        palette: Sequence of RGB triples. Defaults to the module-level ``color_list``.

    Returns:
        float64 array of shape (H, W) holding the label indices.

    Raises:
        ValueError: if ``color_img`` is not an (H, W, C>=3) image.
    """
    if palette is None:
        palette = color_list
    color_img = np.asarray(color_img)
    if color_img.ndim != 3 or color_img.shape[-1] < 3:
        raise ValueError(
            f'expected an (H, W, C>=3) colour image, got shape {color_img.shape}'
        )
    # Drop alpha/extra channels so RGBA masks compare against RGB palette entries.
    rgb = color_img[..., :3]
    label_map = np.zeros(rgb.shape[:2])
    for idx, color in enumerate(palette):
        label_map[np.all(rgb == np.asarray(color), axis=-1)] = idx
    return label_map
def checklabelmap(img):
    """Relabel ``img`` in place so its distinct values become 0, 1, ..., K-1.

    The k-th smallest distinct value is mapped to k. The array keeps its dtype,
    is modified in place, and is also returned.

    NOTE(review): compacting labels shifts class ids whenever a class is absent
    from the mask (e.g. {0, 2} -> {0, 1}); confirm the generator expects that.
    """
    # Single pass via np.unique's inverse indices. A per-label in-place loop can
    # merge classes when an already-assigned new index equals a label that has
    # not been remapped yet (e.g. labels {-1, 0}).
    _, inverse = np.unique(img, return_inverse=True)
    img[...] = inverse.reshape(img.shape)
    return img
def _load_generator(network):
    """Load the ``'G_ema'`` network from the pickle at ``network`` onto ``device``."""
    # On CPU-only hosts use the CPU-remapping loader; otherwise the project's legacy loader.
    load_pkl = load_network_pkl_cpu if device == "cpu" else legacy.load_network_pkl
    with dnnlib.util.open_url(network) as f:
        return load_pkl(f)['G_ema'].eval().to(device)


def _vertex_semantic_colors(G, verts_np, ws, truncation_psi, max_batch=10000000):
    """Return one uint8 palette colour per mesh vertex, chosen by semantic argmax.

    Vertices are queried through ``G.sample_mixed`` in batches of ``max_batch``.
    Feature channels 32..37 of the result are treated as 6 class scores, and
    the argmax class indexes ``color_list``.
    """
    num_verts = verts_np.shape[0]
    semantic_scores = torch.zeros((num_verts, 6), device=device)
    samples = torch.tensor(verts_np, device=device).unsqueeze(0).float()
    with tqdm(total=num_verts) as pbar, torch.no_grad():
        for head in range(0, num_verts, max_batch):
            # Re-seed every batch so the sampled features are reproducible run to run.
            torch.manual_seed(0)
            out = G.sample_mixed(samples[:, head:head + max_batch], None, ws,
                                 truncation_psi=truncation_psi, noise_mode='const')
            seg = out['rgb'][0, :, 32:32 + 6]
            semantic_scores[head:head + max_batch, :] = seg
            pbar.update(seg.shape[0])
    palette = torch.tensor(color_list, device=device)
    return palette[torch.argmax(semantic_scores, dim=-1)].cpu().numpy().astype(np.uint8)


def get_all(cfg, input, truncation_psi, mesh_resolution, random_seed, fps, num_frames):
    """Generate an image, label map, semantic mesh and two orbit videos from a mask.

    Args:
        cfg: Model key. Only the segmentation configs ('seg2cat', 'seg2face')
            have a camera set-up here, and the key must also exist in ``models``.
        input: Colour-coded mask, given as a file path or an array-like image.
        truncation_psi: Forwarded to ``G.sample_mixed`` when colouring mesh vertices.
            NOTE(review): ``G.mapping`` is called without it, so confirm whether
            it affects the generated image at all.
        mesh_resolution: Marching-cubes grid resolution per axis (cast to int).
        random_seed: Seed for the latent ``z`` (cast to int).
        fps: Frame rate of the saved videos.
        num_frames: Number of frames per video (cast to int).

    Returns:
        ``(fig_mesh, image_color, image_seg, video, video_label)``: a Plotly
        figure, an HxWx3 uint8 image, a uint8 colour-mapped label image, and the
        two written .mp4 paths.

    Raises:
        ValueError: if ``cfg`` is not a supported segmentation config or no
            mask is given.
        KeyError: if ``cfg`` has no checkpoint in ``models``.
    """
    if cfg not in ('seg2cat', 'seg2face'):
        # Only segmentation configs define the pose/intrinsics below. 'edge2car'
        # would need neural_rendering_resolution=64 and its own camera set-up.
        raise ValueError(f'Invalid cfg: {cfg!r}')
    if input is None:
        raise ValueError('No input mask was provided.')
    network = models[cfg]

    # UI sliders may deliver floats; range()/torch.linspace need ints.
    mesh_resolution = int(mesh_resolution)
    num_frames = int(num_frames)

    G = _load_generator(network)

    neural_rendering_resolution = 128
    # Fixed camera at yaw = pitch = 3.14/2, at the generator's average pivot and radius.
    forward_cam2world_pose = LookAtPoseSampler.sample(
        3.14 / 2, 3.14 / 2,
        torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device),
        radius=G.rendering_kwargs['avg_camera_radius'], device=device)
    focal_length = 4.2647  # shapenet has higher FOV
    intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
    forward_pose = torch.cat([forward_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)

    # Decode the colour mask into compact integer labels shaped (1, 1, H, W).
    if isinstance(input, str):
        with PIL.Image.open(input) as mask_image:
            input_label = np.asarray(mask_image)
    else:
        input_label = np.asarray(input)
    input_label = checklabelmap(colormap2labelmap(input_label))
    input_label = torch.from_numpy(input_label.astype(np.uint8)).unsqueeze(0).unsqueeze(0).to(device)
    input_pose = forward_pose.to(device)

    # Single conditioned sample from the fixed camera.
    z = torch.from_numpy(np.random.RandomState(int(random_seed)).randn(1, G.z_dim).astype('float32')).to(device)
    with torch.no_grad():
        ws = G.mapping(z, input_pose, {'mask': input_label, 'pose': input_pose})
        out = G.synthesis(ws, input_pose, noise_mode='const', neural_rendering_resolution=neural_rendering_resolution)
    image_color = ((out['image'][0].permute(1, 2, 0).cpu().numpy().clip(-1, 1) + 1) * 127.5).astype(np.uint8)
    image_seg = color_mask(torch.argmax(out['semantic'][0], dim=0).cpu().numpy()).astype(np.uint8)

    # Mesh from the density field, coloured per vertex by semantic class.
    mesh_trimesh = trimesh.Trimesh(*extract_geometry(G, ws, resolution=mesh_resolution, threshold=50.))
    verts_np = np.array(mesh_trimesh.vertices)
    mesh_trimesh.visual.vertex_colors = _vertex_semantic_colors(G, verts_np, ws, truncation_psi)

    frames, frames_label = render_video(G, ws, intrinsics, num_frames=num_frames, pitch_range=0.25, yaw_range=0.35,
                                        neural_rendering_resolution=neural_rendering_resolution, device=device)

    save_dir = Path(outdir)
    video = os.path.join(save_dir, f'{cfg}_color.mp4')
    video_label = os.path.join(save_dir, f'{cfg}_label.mp4')
    imageio.mimsave(video, frames, fps=fps)
    imageio.mimsave(video_label, frames_label, fps=fps)

    fig_mesh = return_plot_go(mesh_trimesh)
    return fig_mesh, image_color, image_seg, video, video_label
# Title and Markdown description shown on the Gradio page.
title = "3D-aware Conditional Image Synthesis"
desc = f'''
[arXiv: "3D-aware Conditional Image Synthesis"](https://arxiv.org/abs/2302.08509)

[Project page](https://www.cs.cmu.edu/~pix2pix3D/)

[Official implementation](https://github.com/dunbar12138/pix2pix3D)

### Future work, based on interest
- Adding models for new object types
- New customization options

It is running on {device}.

Processing can take a long time, especially video generation. The running time depends on the number of frames, the mesh resolution, and the current compute device.
'''
# Input widgets, in the same order as get_all's parameters:
# cfg, input, truncation_psi, mesh_resolution, random_seed, fps, num_frames.
demo_inputs = [
    gr.Dropdown(choices=["seg2cat"], label="Choose Model", value="seg2cat"),
    gr.Image(type="filepath", shape=(512, 512), label="Mask"),
    gr.Slider(minimum=0, maximum=2, label='Truncation PSI', value=1),
    # step=1 pins the integer-valued parameters to whole numbers; the seed
    # slider otherwise relies on Gradio's automatic step for its wide range.
    gr.Slider(minimum=32, maximum=512, step=1, label='Mesh Resolution', value=32),
    gr.Slider(minimum=0, maximum=2**16, step=1, label='Seed', value=128),
    gr.Slider(minimum=10, maximum=120, step=1, label='FPS', value=30),
    gr.Slider(minimum=10, maximum=120, step=1, label='The Number of Frames', value=30),
]

# Output widgets, matching get_all's return tuple order.
demo_outputs = [
    gr.Plot(label="Generated Mesh"),
    gr.Image(type="pil", shape=(256, 256), label="Generated Image"),
    gr.Image(type="pil", shape=(256, 256), label="Generated LabelMap"),
    gr.Video(label="Generated Video "),
    gr.Video(label="Generated Label Video "),
]

# Example rows: one value per input widget above.
examples = [
    ["seg2cat", "img.png", 1, 32, 128, 30, 30],
    ["seg2cat", "img2.png", 1, 32, 128, 30, 30],
    ["seg2cat", "img3.png", 1, 32, 128, 30, 30],
]
# Wire get_all into a Gradio Interface and start the server.
# NOTE(review): cache_examples presumably precomputes the example outputs, which
# means running get_all on each example; `theme="huggingface"` and
# `enable_queue` are Gradio 3.x-era arguments, so confirm them against the pinned version.
demo_app = gr.Interface(
    fn=get_all,
    inputs=demo_inputs,
    outputs=demo_outputs,
    examples=examples,
    cache_examples=True,
    title=title,
    description=desc,
    theme="huggingface",
)
demo_app.launch(debug=True, enable_queue=True)