# Source: freesplatter/webui/runner.py
# Last change: "Update freesplatter/webui/runner.py" by bluestyle97 (commit 944d7dc, verified)
import os
import json
import uuid
import time
import rembg
import numpy as np
import trimesh
import torch
import fpsample
import matplotlib.pyplot as plt
# HSV colormap, sampled at i / num_views below to give each visualized camera its own color.
cmap = plt.get_cmap("hsv")
from torchvision.transforms import v2
from pytorch_lightning import seed_everything
from PIL import Image
from omegaconf import OmegaConf
from einops import rearrange
from scipy.spatial.transform import Rotation
from safetensors import safe_open
from huggingface_hub import hf_hub_download, snapshot_download
from transformers import AutoModelForImageSegmentation
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
from freesplatter.hunyuan.hunyuan3d_mvd_std_pipeline import HunYuan3D_MVD_Std_Pipeline
from freesplatter.utils.mesh_optim import optimize_mesh
from freesplatter.utils.camera_util import *
from freesplatter.utils.recon_util import *
from freesplatter.utils.infer_util import *
from freesplatter.webui.camera_viewer.visualizer import CameraVisualizer
def inv_sigmoid(x: torch.Tensor) -> torch.Tensor:
    """Return the element-wise inverse of the sigmoid (the logit) of ``x``.

    Evaluates ``log(x / (1 - x))``; no clamping is applied to the input.
    """
    odds = x / (1.0 - x)
    return odds.log()
def save_gaussian(latent, gs_vis_path, model, opacity_threshold=None, pad_2dgs_scale=True):
    """Move predicted Gaussians into the world frame and write them to a PLY file.

    Args:
        latent: Gaussian parameter tensor; when it is 3-dimensional only the
            first entry along the leading axis is used. The last axis is split
            into xyz (3), SH features (``model.sh_dim``), opacity (1),
            scaling (2 for 2DGS, otherwise 3) and rotation (4).
        gs_vis_path: Destination path handed to ``save_ply_vis``.
        model: Object exposing ``sh_dim``, ``use_2dgs`` and
            ``gs_renderer.gaussian_model``.
        opacity_threshold: When not ``None``, only Gaussians whose
            ``sigmoid(opacity)`` exceeds this value are kept.
        pad_2dgs_scale: When ``True`` and the scaling has two components, a
            third (z) scale is appended for visualization.
    """
    # Work on a single set of Gaussians: drop the leading axis if present.
    gaussians = latent[0] if latent.ndim == 3 else latent

    n_sh = model.sh_dim
    n_scale = 2 if model.use_2dgs else 3
    xyz, features, opacity, scaling, rotation = gaussians.split(
        [3, n_sh, 1, n_scale, 4], dim=-1)
    features = features.reshape(features.shape[0], n_sh // 3, 3)

    if opacity_threshold is not None:
        keep = torch.nonzero(opacity.sigmoid() > opacity_threshold)[:, 0]
        xyz, features, opacity, scaling, rotation = (
            part[keep] for part in (xyz, features, opacity, scaling, rotation))

    # transform gaussians from reference view to world view
    cam2world = create_camera_to_world(
        torch.tensor([0, -2, 0]), camera_system='opencv').to(gaussians)
    rot_mat = cam2world[:3, :3]
    translation = cam2world[:3, 3].reshape(1, 3)
    xyz = xyz @ rot_mat.T + translation

    # The column shuffles convert between the stored quaternion layout and the
    # scalar-last layout scipy's Rotation uses, and back again.
    quat_scalar_last = rotation.detach().cpu().numpy()[:, [1, 2, 3, 0]]
    local_rot = Rotation.from_quat(quat_scalar_last).as_matrix()
    world_rot = rot_mat.detach().cpu().numpy() @ local_rot
    quat_scalar_first = Rotation.from_matrix(world_rot).as_quat()[:, [3, 0, 1, 2]]
    rotation = torch.from_numpy(quat_scalar_first).to(gaussians)

    # pad 2DGS with an additional z-scale for visualization
    if pad_2dgs_scale and scaling.shape[-1] == 2:
        z_scaling = inv_sigmoid(torch.ones_like(scaling[:, :1]) * 0.001)
        scaling = torch.cat([scaling, z_scaling], dim=-1)

    pc_vis = model.gs_renderer.gaussian_model.set_data(
        xyz.float(),
        features.float(),
        scaling.float(),
        rotation.float(),
        opacity.float(),
    )
    pc_vis.save_ply_vis(gs_vis_path)
class FreeSplatterRunner:
    """Hold the models used by the FreeSplatter web UI and run its pipelines.

    The constructor loads a background-removal network, three multi-view
    diffusion pipelines (Zero123++ v1.1, Zero123++ v1.2, Hunyuan3D Std) and
    three FreeSplatter checkpoints (object, object-2DGS, scene) onto
    ``device``. Every public ``run_*`` entry point that takes ``cache_dir``
    creates a fresh ``output_<uuid>`` directory below it and stores the path
    in ``self.output_dir``; artifacts are written there.
    """

    def __init__(self, device):
        """Load all models and move them to ``device``.

        FreeSplatter configs are read from the relative ``configs/`` directory,
        so the process must be started from a directory that contains it.
        """
        self.device = device

        # background remover
        self.rembg = AutoModelForImageSegmentation.from_pretrained(
            # "ZhengPeng7/BiRefNet",
            "briaai/RMBG-2.0",
            trust_remote_code=True,
        ).to(device)
        self.rembg.eval()
        # self.rembg = rembg.new_session('birefnet-general')

        # diffusion models
        self.zero123plus_v11 = self._load_zero123plus("sudo-ai/zero123plus-v1.1", device)
        self.zero123plus_v12 = self._load_zero123plus("sudo-ai/zero123plus-v1.2", device)

        download_dir = snapshot_download('tencent/Hunyuan3D-1', repo_type='model')
        pipeline = HunYuan3D_MVD_Std_Pipeline.from_pretrained(
            os.path.join(download_dir, 'mvd_std'),
            torch_dtype=torch.float16,
            use_safetensors=True,
        )
        self.hunyuan3d_mvd_std = pipeline.to(device)

        # freesplatter
        self.freesplatter = self._load_freesplatter(
            'configs/freesplatter-object.yaml', 'freesplatter-object.safetensors', device)
        self.freesplatter_2dgs = self._load_freesplatter(
            'configs/freesplatter-object-2dgs.yaml', 'freesplatter-object-2dgs.safetensors', device)
        self.freesplatter_scene = self._load_freesplatter(
            'configs/freesplatter-scene.yaml', 'freesplatter-scene.safetensors', device)

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _load_zero123plus(repo_id, device):
        """Load a Zero123++ pipeline in fp16 with a 'trailing' Euler-ancestral scheduler."""
        pipeline = DiffusionPipeline.from_pretrained(
            repo_id,
            custom_pipeline="sudo-ai/zero123plus-pipeline",
            torch_dtype=torch.float16,
        )
        pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
            pipeline.scheduler.config, timestep_spacing='trailing'
        )
        return pipeline.to(device)

    @staticmethod
    def _load_freesplatter(config_file, ckpt_filename, device):
        """Build a FreeSplatter model from ``config_file`` and load its safetensors weights."""
        ckpt_path = hf_hub_download(
            'TencentARC/FreeSplatter', repo_type='model', filename=ckpt_filename)
        model = instantiate_from_config(OmegaConf.load(config_file).model)
        with safe_open(ckpt_path, framework="pt", device="cpu") as f:
            state_dict = {key: f.get_tensor(key) for key in f.keys()}
        model.load_state_dict(state_dict, strict=True)
        return model.eval().to(device)

    def _prepare_output_dir(self, cache_dir):
        """Create a unique output directory under ``cache_dir`` and remember it."""
        self.output_dir = os.path.join(cache_dir, f'output_{uuid.uuid4()}')
        os.makedirs(self.output_dir, exist_ok=True)
        return self.output_dir

    @staticmethod
    def _pad_outputs(res, total=6):
        """Return a new list: ``res`` followed by ``None`` up to ``total`` entries."""
        return res + [None] * (total - len(res))

    @staticmethod
    def _collect_object_outputs(stages):
        """Drain ``run_freesplatter_object`` and reorder its four results.

        ``stages`` yields ``fig, gs_vis_path, video_path, mesh_fine_path``;
        the returned tuple is ``(gs_vis_path, video_path, mesh_fine_path, fig)``.
        """
        fig, gs_vis_path, video_path, mesh_fine_path = stages
        return gs_vis_path, video_path, mesh_fine_path, fig

    @staticmethod
    def _resize_pair(image, alpha, size=512):
        """Resize an image/alpha pair to ``size`` and clamp both to [0, 1]."""
        image = v2.functional.resize(image, size, interpolation=3, antialias=True).clamp(0, 1)
        alpha = v2.functional.resize(alpha, size, interpolation=0, antialias=True).clamp(0, 1)
        return image, alpha

    @staticmethod
    def _normalized_intrinsics(focals_pred, num_views, device):
        """Build per-view ``[fx, fy, cx, cy]`` rows from the mean predicted focal.

        The focal is divided by 512 (the working image size used in this file)
        and the principal point is fixed at 0.5. Returns ``(fx, intrinsics)``
        where ``intrinsics`` has shape ``(num_views, 4)`` on ``device``.
        """
        fx = fy = focals_pred.mean() / 512.0
        cx = cy = torch.ones_like(fx) * 0.5
        intrinsics = torch.tensor([fx, fy, cx, cy]).unsqueeze(0).repeat(num_views, 1).to(device)
        return fx, intrinsics

    def _run_multiview_diffusion(self, input_image, model, diffusion_steps, guidance_scale):
        """Run the multi-view diffusion pipeline selected by ``model``.

        Raises:
            ValueError: If ``model`` is not one of the three supported names.
        """
        pipelines = {
            'Zero123++ v1.1': self.zero123plus_v11,
            'Zero123++ v1.2': self.zero123plus_v12,
            'Hunyuan3D Std': self.hunyuan3d_mvd_std,
        }
        if model not in pipelines:
            raise ValueError(f'Unknown model: {model}')

        # Only the Hunyuan3D pipeline is given a (constant) guidance curve.
        extra = {'guidance_curve': lambda t: 2.0} if model == 'Hunyuan3D Std' else {}
        return pipelines[model](
            input_image,
            num_inference_steps=diffusion_steps,
            guidance_scale=guidance_scale,
            **extra,
        ).images[0]

    @staticmethod
    def _camera_figure(images, c2ws, focal_length, legends, bounds):
        """Render the camera-pose figure shared by the object and scene views.

        ``bounds`` is forwarded as the first positional argument of
        ``CameraVisualizer.update_figure`` (presumably the scene extent;
        confirm against the visualizer).
        """
        thumbs = v2.functional.resize(images, 128, interpolation=3, antialias=True).clamp(0, 1)
        thumbs = (thumbs.permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype(np.uint8)

        # Copy before the in-place flip: .numpy() on a CPU tensor shares memory
        # with it, and the caller may still need the untouched poses.
        poses = c2ws.detach().cpu().numpy().copy()
        poses[:, :, 1:3] *= -1    # opencv to opengl

        focal = focal_length.mean().detach().cpu().numpy()
        fov = np.rad2deg(np.arctan(256.0 / focal)) * 2

        num_views = len(thumbs)
        colors = [cmap(i / num_views)[:3] for i in range(num_views)]
        legends = [None] * num_views if legends is None else legends

        viz = CameraVisualizer(poses, legends, colors, images=thumbs)
        return viz.update_figure(
            bounds,
            height=320,
            line_width=5,
            base_radius=1,
            zoom_scale=1,
            fov_deg=fov,
            show_grid=True,
            show_ticklabels=True,
            show_background=True,
            y_up=False,
        )

    # ------------------------------------------------------------------
    # public API
    # ------------------------------------------------------------------

    @torch.inference_mode()
    def run_segmentation(
        self,
        image,
        do_rembg=True,
    ):
        """Return ``image`` with its background removed, or unchanged if ``do_rembg`` is false."""
        if do_rembg:
            image = remove_background(image, self.rembg)
        return image

    def run_img_to_3d(
        self,
        image,
        model='Zero123++ v1.2',
        diffusion_steps=30,
        guidance_scale=4.0,
        seed=42,
        view_indices=None,
        gs_type='2DGS',
        mesh_reduction=0.5,
        cache_dir=None,
    ):
        """Generate a 3D object from a single image, yielding results progressively.

        This is a generator. Each yield is a 6-element list, padded with
        ``None`` for stages not finished yet, filled in this order: segmented
        input, multi-view image strip, camera figure, Gaussian PLY path, video
        path, optimized mesh path.

        Args:
            image: Input image passed to background removal.
            model: ``'Zero123++ v1.1'``, ``'Zero123++ v1.2'`` or ``'Hunyuan3D Std'``.
            diffusion_steps: Number of diffusion inference steps.
            guidance_scale: Guidance scale for the diffusion pipeline.
            seed: Seed passed to ``seed_everything`` before diffusion.
            view_indices: Indices into the 7 views (0 is the input image, 1-6
                the generated views) used for reconstruction. ``None`` or an
                empty sequence selects views 1-6.
            gs_type: ``'2DGS'`` selects the 2DGS checkpoint, anything else the
                default object checkpoint.
            mesh_reduction: Forwarded to ``optimize_mesh`` as ``simplify``.
            cache_dir: Parent directory for the output folder.

        Raises:
            ValueError: If ``model`` is not a supported name.
        """
        image_rgba = self.run_segmentation(image)
        res = [image_rgba]
        yield self._pad_outputs(res)

        self._prepare_output_dir(cache_dir)

        # image-to-multiview
        input_image = resize_foreground(image_rgba, 0.9)
        seed_everything(seed)
        output_image = self._run_multiview_diffusion(
            input_image, model, diffusion_steps, guidance_scale)

        # preprocess images
        image, alpha = rgba_to_white_background(input_image)
        image, alpha = self._resize_pair(image, alpha)

        output_image_rgba = remove_background(output_image, self.rembg)
        if 'Zero123++' in model:
            images, alphas = rgba_to_white_background(output_image_rgba)
        else:
            # Keep the raw diffusion output as the image; only the alpha comes
            # from the background-removed version.
            _, alphas = rgba_to_white_background(output_image_rgba)
            images = torch.from_numpy(np.asarray(output_image) / 255.0).float()
            images = rearrange(images, 'h w c -> c h w')

        # The diffusion output is a 3x2 grid of views; split it into 6 images.
        images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)
        alphas = rearrange(alphas, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)
        if model == 'Hunyuan3D Std':
            # Hunyuan3D views are permuted after splitting the grid
            # (presumably to match the Zero123++ view order; confirm).
            images = images[[0, 2, 4, 5, 3, 1]]
            alphas = alphas[[0, 2, 4, 5, 3, 1]]
        images_vis = v2.functional.to_pil_image(rearrange(images, 'nm c h w -> c h (nm w)'))
        res += [images_vis]
        yield self._pad_outputs(res)

        images, alphas = self._resize_pair(images, alphas)

        images = torch.cat([image.unsqueeze(0), images], dim=0)     # 7 x 3 x 512 x 512
        alphas = torch.cat([alpha.unsqueeze(0), alphas], dim=0)     # 7 x 1 x 512 x 512

        # run reconstruction
        if view_indices is None or len(view_indices) == 0:
            view_indices = [1, 2, 3, 4, 5, 6]
        images, alphas = images[view_indices], alphas[view_indices]
        legends = [f'V{i}' if i != 0 else 'Input' for i in view_indices]

        for item in self.run_freesplatter_object(
                images, alphas, legends=legends, gs_type=gs_type, mesh_reduction=mesh_reduction):
            res += [item]
            yield self._pad_outputs(res)

    def run_views_to_3d(
        self,
        image_files,
        do_rembg=False,
        gs_type='2DGS',
        mesh_reduction=0.5,
        cache_dir=None,
    ):
        """Reconstruct a 3D object from several user-provided view images.

        Args:
            image_files: Paths (or tuples whose first item is a path) of the
                input views.
            do_rembg: NOTE(review): accepted but currently not forwarded;
                ``run_segmentation`` is always called with its default.
            gs_type: ``'2DGS'`` selects the 2DGS checkpoint, anything else the
                default object checkpoint.
            mesh_reduction: Forwarded to ``optimize_mesh`` as ``simplify``.
            cache_dir: Parent directory for the output folder.

        Returns:
            ``(images_vis, gs_vis_path, video_path, mesh_fine_path, fig)``.
        """
        self._prepare_output_dir(cache_dir)

        # preprocess images
        images, alphas = [], []
        for image_file in image_files:
            if isinstance(image_file, tuple):
                image_file = image_file[0]
            image = Image.open(image_file)
            w, h = image.size

            image_rgba = self.run_segmentation(image)
            if image.mode == 'RGBA':
                image, alpha = rgba_to_white_background(image_rgba)
                image = v2.functional.center_crop(image, min(h, w))
                alpha = v2.functional.center_crop(alpha, min(h, w))
            else:
                image_rgba = resize_foreground(image_rgba, 0.9)
                image, alpha = rgba_to_white_background(image_rgba)

            image, alpha = self._resize_pair(image, alpha)
            images.append(image)
            alphas.append(alpha)

        images = torch.stack(images, dim=0)
        alphas = torch.stack(alphas, dim=0)
        images_vis = v2.functional.to_pil_image(rearrange(images, 'n c h w -> c h (n w)'))

        # run reconstruction
        legends = [f'V{i}' for i in range(1, 1 + len(images))]

        # run_freesplatter_object is a generator that yields the figure first;
        # reorder its results into the order this method returns them.
        gs_vis_path, video_path, mesh_fine_path, fig = self._collect_object_outputs(
            self.run_freesplatter_object(
                images, alphas, legends=legends, gs_type=gs_type, mesh_reduction=mesh_reduction))

        return images_vis, gs_vis_path, video_path, mesh_fine_path, fig

    def run_freesplatter_object(
        self,
        images,
        alphas,
        legends=None,
        gs_type='2DGS',
        mesh_reduction=0.5,
    ):
        """Run object reconstruction and yield each artifact as it becomes ready.

        This is a generator that yields, in order: the camera-pose figure, the
        Gaussian PLY path, the turntable video path and the optimized mesh
        (GLB) path. ``self.output_dir`` must already be set.

        Args:
            images: View images; moved to ``self.device`` and batched with
                ``unsqueeze(0)`` before inference.
            alphas: Foreground masks matching ``images``, used for pose
                estimation.
            legends: Optional per-view labels for the camera figure.
            gs_type: ``'2DGS'`` selects the 2DGS checkpoint, anything else the
                default object checkpoint.
            mesh_reduction: Forwarded to ``optimize_mesh`` as ``simplify``.
        """
        device = self.device
        freesplatter = self.freesplatter_2dgs if gs_type == '2DGS' else self.freesplatter

        images, alphas = images.to(device), alphas.to(device)

        t0 = time.time()
        with torch.inference_mode():
            gaussians = freesplatter.forward_gaussians(images.unsqueeze(0))
        t1 = time.time()

        # estimate camera parameters and visualize
        c2ws_pred, focals_pred = freesplatter.estimate_poses(
            images, gaussians, masks=alphas, use_first_focal=True, pnp_iter=10)
        fig = self.visualize_cameras_object(images, c2ws_pred, focals_pred, legends=legends)
        t2 = time.time()
        yield fig

        # save gaussians
        gs_vis_path = os.path.join(self.output_dir, 'gs_vis.ply')
        save_gaussian(gaussians, gs_vis_path, freesplatter, opacity_threshold=5e-3, pad_2dgs_scale=True)
        print(f'Save gaussian at {gs_vis_path}')
        yield gs_vis_path

        # render video
        with torch.inference_mode():
            c2ws_video = get_circular_cameras(N=120, elevation=0, radius=2.0, normalize=True).to(device)
            _, fxfycxcy_video = self._normalized_intrinsics(focals_pred, c2ws_video.shape[0], device)

            video_frames = freesplatter.forward_renderer(
                gaussians,
                c2ws_video.unsqueeze(0),
                fxfycxcy_video.unsqueeze(0),
            )['image'][0].clamp(0, 1)

        video_path = os.path.join(self.output_dir, 'gs.mp4')
        save_video(video_frames, video_path, fps=30)
        print(f'Save video at {video_path}')
        t3 = time.time()
        yield video_path

        # extract mesh
        with torch.inference_mode():
            c2ws_fusion = get_fibonacci_cameras(N=120, radius=2.0)
            c2ws_fusion, _ = normalize_cameras(
                c2ws_fusion, camera_position=torch.tensor([0., -2., 0.]), camera_system='opencv')
            c2ws_fusion = c2ws_fusion.to(device)
            # Express every fusion camera relative to the first one before rendering.
            c2ws_fusion_reference = torch.linalg.inv(c2ws_fusion[0:1]) @ c2ws_fusion
            fx, fxfycxcy_fusion = self._normalized_intrinsics(focals_pred, c2ws_fusion.shape[0], device)
            fov = np.rad2deg(np.arctan(0.5 / fx.item())) * 2

            fusion_render_results = freesplatter.forward_renderer(
                gaussians,
                c2ws_fusion_reference.unsqueeze(0),
                fxfycxcy_fusion.unsqueeze(0),
            )
            images_fusion = fusion_render_results['image'][0].clamp(0, 1).permute(0, 2, 3, 1)
            alphas_fusion = fusion_render_results['alpha'][0].permute(0, 2, 3, 1)
            depths_fusion = fusion_render_results['depth'][0].permute(0, 2, 3, 1)

        fusion_images = (images_fusion.detach().cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
        fusion_depths = depths_fusion.detach().cpu().numpy()
        fusion_alphas = alphas_fusion.detach().cpu().numpy()
        fusion_masks = (fusion_alphas > 1e-2).astype(np.uint8)
        # Keep depth where the mask is set; everywhere else the depth becomes -1.
        fusion_depths = fusion_depths * fusion_masks - np.ones_like(fusion_depths) * (1 - fusion_masks)
        fusion_c2ws = c2ws_fusion.detach().cpu().numpy()

        mesh_path = os.path.join(self.output_dir, 'mesh.obj')
        rgbd_to_mesh(
            fusion_images, fusion_depths, fusion_c2ws, fov, mesh_path, cam_elev_thr=-90)    # use all angles for tsdf fusion
        print(f'Save mesh at {mesh_path}')
        t4 = time.time()

        # optimize texture
        cam_pos = c2ws_fusion[:, :3, 3].cpu().numpy()
        # Farthest-point sampling picks 16 of the fusion cameras for baking.
        cam_inds = torch.from_numpy(fpsample.fps_sampling(cam_pos, 16).astype(int)).to(device=device)

        alphas_bake = alphas_fusion[cam_inds]
        # Undo the white-background compositing to recover foreground color.
        images_bake = (images_fusion[cam_inds] - (1 - alphas_bake)) / alphas_bake.clamp(min=1e-6)

        fxfycxcy = fxfycxcy_fusion[cam_inds].clone()
        intrinsics = torch.eye(3).unsqueeze(0).repeat(len(cam_inds), 1, 1).to(fxfycxcy)
        intrinsics[:, 0, 0] = fxfycxcy[:, 0]
        intrinsics[:, 0, 2] = fxfycxcy[:, 2]
        intrinsics[:, 1, 1] = fxfycxcy[:, 1]
        intrinsics[:, 1, 2] = fxfycxcy[:, 3]

        out_mesh = trimesh.load(str(mesh_path), process=False)
        out_mesh = optimize_mesh(
            out_mesh,
            images_bake,
            alphas_bake.squeeze(-1),
            c2ws_fusion[cam_inds].inverse(),
            intrinsics,
            simplify=mesh_reduction,
            verbose=False
        )
        mesh_fine_path = os.path.join(self.output_dir, 'mesh.glb')
        out_mesh.export(mesh_fine_path)
        print(f"Save optimized mesh at {mesh_fine_path}")
        t5 = time.time()

        print(f'Generate Gaussians: {t1-t0:.2f} seconds.')
        print(f'Estimate poses: {t2-t1:.2f} seconds.')
        print(f'Generate video: {t3-t2:.2f} seconds.')
        print(f'Generate mesh: {t4-t3:.2f} seconds.')
        print(f'Optimize mesh: {t5-t4:.2f} seconds.')
        yield mesh_fine_path

    def visualize_cameras_object(
        self,
        images,
        c2ws,
        focal_length,
        legends=None,
    ):
        """Build the camera-pose figure for object reconstruction.

        The poses are first re-expressed so that the first camera coincides
        with the camera built by ``create_camera_to_world([0, -2, 0])``.
        """
        cam2world = create_camera_to_world(torch.tensor([0, -2, 0]), camera_system='opencv').to(c2ws)
        transform = cam2world @ torch.linalg.inv(c2ws[0:1])
        c2ws = transform @ c2ws
        return self._camera_figure(images, c2ws, focal_length, legends, 3)

    # FreeSplatter-S
    def run_views_to_scene(
        self,
        image1,
        image2,
        cache_dir=None,
    ):
        """Reconstruct a scene from two views.

        Args:
            image1: First view; must expose ``.size`` and convert via
                ``np.asarray`` to an ``h w c`` array (assumes a 3-channel PIL
                image; TODO confirm).
            image2: Second view, same requirements.
            cache_dir: Parent directory for the output folder.

        Returns:
            ``(images_vis, gs_vis_path, video_path, fig)``.
        """
        self._prepare_output_dir(cache_dir)

        # preprocess images
        images = []
        for image in [image1, image2]:
            w, h = image.size
            image = torch.from_numpy(np.asarray(image) / 255.0).float()
            image = rearrange(image, 'h w c -> c h w')
            image = v2.functional.center_crop(image, min(h, w))
            image = v2.functional.resize(image, 512, interpolation=3, antialias=True).clamp(0, 1)
            images.append(image)

        images = torch.stack(images, dim=0)
        images_vis = v2.functional.to_pil_image(rearrange(images, 'n c h w -> c h (n w)'))

        # run reconstruction
        legends = [f'V{i}' for i in range(1, 1 + len(images))]

        gs_vis_path, video_path, fig = self.run_freesplatter_scene(images, legends=legends)

        return images_vis, gs_vis_path, video_path, fig

    def run_freesplatter_scene(
        self,
        images,
        legends=None,
    ):
        """Run scene reconstruction and return ``(gs_vis_path, video_path, fig)``.

        ``self.output_dir`` must already be set; the Gaussian PLY and the
        interpolated-path video are written there.
        """
        freesplatter = self.freesplatter_scene
        device = self.device

        images = images.to(device)

        t0 = time.time()
        with torch.inference_mode():
            gaussians = freesplatter.forward_gaussians(images.unsqueeze(0))
        t1 = time.time()

        # estimate camera parameters
        c2ws_pred, focals_pred = freesplatter.estimate_poses(
            images, gaussians, use_first_focal=True, pnp_iter=10)
        # rescale cameras to make the baseline equal to 1.0
        baseline_pred = (c2ws_pred[:, :3, 3] - c2ws_pred[:1, :3, 3]).norm() + 1e-2
        scale_factor = 1.0 / baseline_pred
        c2ws_pred = c2ws_pred.clone()
        c2ws_pred[:, :3, 3] *= scale_factor

        # visualize cameras
        fig = self.visualize_cameras_scene(images, c2ws_pred, focals_pred, legends=legends)
        t2 = time.time()

        # save gaussians
        gs_vis_path = os.path.join(self.output_dir, 'gs_vis.ply')
        save_gaussian(gaussians, gs_vis_path, freesplatter, opacity_threshold=5e-3)
        print(f'Save gaussian at {gs_vis_path}')

        # render video
        with torch.inference_mode():
            c2ws_video = generate_interpolated_path(
                c2ws_pred.detach().cpu().numpy()[:, :3, :], n_interp=120)
            # Append the homogeneous [0, 0, 0, 1] row to each interpolated 3x4 pose.
            c2ws_video = torch.cat([
                torch.from_numpy(c2ws_video),
                torch.tensor([0, 0, 0, 1]).reshape(1, 1, 4).repeat(c2ws_video.shape[0], 1, 1)
            ], dim=1).to(gaussians)
            _, fxfycxcy_video = self._normalized_intrinsics(focals_pred, c2ws_video.shape[0], device)

            video_frames = freesplatter.forward_renderer(
                gaussians,
                c2ws_video.unsqueeze(0),
                fxfycxcy_video.unsqueeze(0),
                rescale=scale_factor.reshape(1).to(gaussians)
            )['image'][0].clamp(0, 1)

        video_path = os.path.join(self.output_dir, 'gs.mp4')
        save_video(video_frames, video_path, fps=30)
        print(f'Save video at {video_path}')
        t3 = time.time()

        print(f'Generate Gaussians: {t1-t0:.2f} seconds.')
        print(f'Estimate poses: {t2-t1:.2f} seconds.')
        print(f'Generate video: {t3-t2:.2f} seconds.')

        return gs_vis_path, video_path, fig

    def visualize_cameras_scene(
        self,
        images,
        c2ws,
        focal_length,
        legends=None,
    ):
        """Build the camera-pose figure for scene reconstruction.

        ``c2ws`` is left unmodified; the axis flip is applied to a copy.
        """
        return self._camera_figure(images, c2ws, focal_length, legends, 2)