# dreamgaussian4d/lgm/infer_demo.py
# jiaweir
# optimize
# cdc7dcc
import os
import tyro
import glob
import imageio
import numpy as np
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from safetensors.torch import load_file
import kiui
from kiui.op import recenter
from kiui.cam import orbit_camera
from core.options import AllConfigs, Options
from core.models import LGM
from mvdream.pipeline_mvdream import MVDreamPipeline
import cv2
# ImageNet per-channel normalization statistics, used with TF.normalize when
# turning rendered / generated views into the LGM network input.
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD = (
    (0.485, 0.456, 0.406),
    (0.229, 0.224, 0.225),
)
# opt = tyro.cli(AllConfigs)
# # model
# model = LGM(opt)
# # resume pretrained checkpoint
# if opt.resume is not None:
# if opt.resume.endswith('safetensors'):
# ckpt = load_file(opt.resume, device='cpu')
# else:
# ckpt = torch.load(opt.resume, map_location='cpu')
# model.load_state_dict(ckpt, strict=False)
# print(f'[INFO] Loaded checkpoint from {opt.resume}')
# else:
# print(f'[WARN] model randomly initialized, are you sure?')
# # device
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = model.half().to(device)
# model.eval()
def _orbit_render(model, gaussians, opt, elevation, azimuth, proj_matrix, device):
    """Render `gaussians` from one orbit camera and return ``result['image']``.

    The camera sits at `elevation`/`azimuth` (degrees) on a sphere of radius
    ``opt.cam_radius``; `proj_matrix` is the shared 4x4 projection matrix.
    The returned tensor is presumably [B, V, 3, H, W] with B = V = 1 -- TODO confirm
    against ``model.gs.render``.
    """
    cam_poses = torch.from_numpy(
        orbit_camera(elevation, azimuth, radius=opt.cam_radius, opengl=True)
    ).unsqueeze(0).to(device)
    cam_poses[:, :3, 1:3] *= -1  # invert up & forward direction
    # cameras needed by gaussian rasterizer
    cam_view = torch.inverse(cam_poses).transpose(1, 2)  # [V, 4, 4]
    cam_view_proj = cam_view @ proj_matrix  # [V, 4, 4]
    cam_pos = -cam_poses[:, :3, 3]  # [V, 3]
    result = model.gs.render(
        gaussians,
        cam_view.unsqueeze(0),
        cam_view_proj.unsqueeze(0),
        cam_pos.unsqueeze(0),
        scale_modifier=1,
    )
    return result['image']


def _build_lgm_input(mv_image, opt, rays_embeddings, device):
    """Turn a float numpy stack of 4 RGB views ([4, H, W, 3]) into the LGM input.

    The views are resized to ``opt.input_size`` (bilinear), ImageNet-normalized,
    and concatenated with `rays_embeddings` along the channel dim, giving
    [1, 4, 9, input_size, input_size].
    """
    input_image = torch.from_numpy(mv_image).permute(0, 3, 1, 2).float().to(device)  # [4, 3, H, W]
    input_image = F.interpolate(
        input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False
    )
    input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
    return torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0)  # [1, 4, 9, H, W]


def process(opt: Options, path, pipe, model, rays_embeddings, seed):
    """Build a static Gaussian model from one image and render a turntable video.

    Pipeline:
      1. Run the multi-view diffusion `pipe` on the input image to get 4 views.
      2. Reconstruct Gaussians with ``model.forward_gaussians`` and search azimuths
         in [-180, 180) (1 degree steps, elevation 0) for the render with the
         lowest mean squared error against the input image.
      3. Re-render the views at azimuth offsets 0/90/180/270 from that best
         azimuth, reconstruct again with ``model.forward_gaussians_downsample``.
         The first returned Gaussians are saved to ``logs/<name>_model.ply``.
         The second (``gaussians_orig_res``) are rendered every 2 degrees at
         elevation 0 into ``vis_data/<name>_static.mp4`` (30 fps).

    Args:
        opt: LGM options; reads fovy, znear, zfar, input_size and cam_radius.
        path: input image path; its basename without extension names the outputs.
        pipe: multi-view diffusion pipeline called as ``pipe('', image, ...)``.
        model: LGM model exposing forward_gaussians, forward_gaussians_downsample
            and ``gs.render`` / ``gs.save_ply``.
        rays_embeddings: per-view ray embeddings concatenated to the 3 RGB
            channels (6 channels, giving the 9-channel input).
        seed: passed to ``torch.manual_seed`` before diffusion sampling.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # perspective projection shared by every render below
    tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
    proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device)
    proj_matrix[0, 0] = 1 / tan_half_fov
    proj_matrix[1, 1] = 1 / tan_half_fov
    proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
    proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
    proj_matrix[2, 3] = 1

    name = os.path.splitext(os.path.basename(path))[0]
    print(f'[INFO] Processing {path} --> {name}')
    os.makedirs('vis_data', exist_ok=True)
    os.makedirs('logs', exist_ok=True)

    # input image as float32 in [0, 1]
    ref_image = kiui.read_image(path, mode='uint8').astype(np.float32) / 255.0
    # rgba to rgb white bg
    if ref_image.shape[-1] == 4:
        ref_image = ref_image[..., :3] * ref_image[..., 3:4] + (1 - ref_image[..., 3:4])

    # generate mv
    generator = torch.manual_seed(seed)
    mv_image = pipe('', ref_image, guidance_scale=5.0, num_inference_steps=30, elevation=0, generator=generator)
    mv_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0)  # [4, 256, 256, 3], float32

    input_image = _build_lgm_input(mv_image, opt, rays_embeddings, device)

    with torch.inference_mode():
        ############## align azimuth #####################
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            gaussians = model.forward_gaussians(input_image)

        # cv2.resize expects dsize=(width, height), whereas ndarray.shape is (height, width, ...)
        ref_h, ref_w = ref_image.shape[:2]
        best_azi = 0
        best_diff = np.inf
        for azi in np.arange(-180, 180, 1):
            rendered = _orbit_render(model, gaussians, opt, 0, azi, proj_matrix, device)
            rendered = rendered.squeeze(1).permute(0, 2, 3, 1).squeeze(0).contiguous().float().cpu().numpy()
            rendered = cv2.resize(rendered, (ref_w, ref_h), interpolation=cv2.INTER_AREA)
            diff = np.mean((rendered - ref_image) ** 2)
            if diff < best_diff:
                best_diff = diff
                best_azi = azi
        print("Best aligned azimuth: ", best_azi)

        # re-render the 4 input views relative to the aligned azimuth
        views = []
        for azi in (0, 90, 180, 270):
            rendered = _orbit_render(model, gaussians, opt, 0, azi + best_azi, proj_matrix, device)
            rendered = F.interpolate(rendered.squeeze(1), (256, 256))
            views.append(rendered.permute(0, 2, 3, 1).contiguous().float().cpu().numpy())
        mv_image = np.concatenate(views, axis=0)
        input_image = _build_lgm_input(mv_image, opt, rays_embeddings, device)
        ################################

        with torch.autocast(device_type='cuda', dtype=torch.float16):
            gaussians, gaussians_orig_res = model.forward_gaussians_downsample(input_image)

        # save gaussians
        model.gs.save_ply(gaussians, os.path.join('logs', name + '_model.ply'))

        # render 360 video
        elevation = 0
        frames = []
        for azi in tqdm.tqdm(np.arange(0, 360, 2, dtype=np.int32)):
            frame = _orbit_render(model, gaussians_orig_res, opt, elevation, azi, proj_matrix, device)
            frame = frame.squeeze(1).permute(0, 2, 3, 1).contiguous().float().cpu().numpy()
            frames.append((frame * 255).astype(np.uint8))
        frames = np.concatenate(frames, axis=0)
        imageio.mimwrite(os.path.join('vis_data', name + '_static.mp4'), frames, fps=30)