vggt / gradio_util.py
JianyuanWang's picture
update
d35072c
raw
history blame
12.6 kB
# import os
import trimesh
# import open3d as o3d
import gradio as gr
import numpy as np
import matplotlib
from scipy.spatial.transform import Rotation
import copy
import cv2
import os
def demo_predictions_to_glb(predictions, conf_thres=3.0, filter_by_frames="all", mask_black_bg=False, show_cam=True, mask_sky=False, target_dir=None) -> trimesh.Scene:
"""
Converts VGG SFM predictions to a 3D scene represented as a GLB.
Args:
predictions (dict): A dictionary containing model predictions.
Returns:
trimesh.Scene: A 3D scene object.
"""
# Convert predictions to numpy arrays
# pred_extrinsic_list', 'pred_world_points', 'pred_world_points_conf', 'images', 'last_pred_extrinsic
if conf_thres is None:
conf_thres = 0.0
print("Building GLB scene")
selected_frame_idx = None
if filter_by_frames != "all":
try:
# Extract the index part before the colon
selected_frame_idx = int(filter_by_frames.split(":")[0])
except (ValueError, IndexError):
pass
pred_world_points = predictions["pred_world_points"][0] # remove batch dimension
pred_world_points_conf = predictions["pred_world_points_conf"][0]
images = predictions["images"][0]
last_pred_extrinsic = predictions["last_pred_extrinsic"][0]
if mask_sky:
if target_dir is not None:
import onnxruntime
skyseg_session = None
target_dir_images = target_dir + "/images"
image_list = sorted(os.listdir(target_dir_images))
sky_mask_list = []
# Get the shape of pred_world_points_conf to match
S, H, W = pred_world_points_conf.shape
for i, image_name in enumerate(image_list):
image_filepath = os.path.join(target_dir_images, image_name)
mask_filepath = os.path.join(target_dir, "sky_masks", image_name)
# Check if mask already exists
if os.path.exists(mask_filepath):
# Load existing mask
sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
else:
# Generate new mask
if skyseg_session is None:
skyseg_session = onnxruntime.InferenceSession("skyseg.onnx")
sky_mask = segment_sky(image_filepath, skyseg_session, mask_filepath)
# Resize mask to match H×W if needed
if sky_mask.shape[0] != H or sky_mask.shape[1] != W:
sky_mask = cv2.resize(sky_mask, (W, H))
# model_was_never_trained_on_single_image_or_oil_painting
# no overlap
# single view, youhua
# single view, catoon
sky_mask_list.append(sky_mask)
# Convert list to numpy array with shape S×H×W
sky_mask_array = np.array(sky_mask_list)
# Apply sky mask to confidence scores
sky_mask_binary = (sky_mask_array > 0.01).astype(np.float32)
pred_world_points_conf = pred_world_points_conf * sky_mask_binary
if selected_frame_idx is not None:
pred_world_points = pred_world_points[selected_frame_idx][None]
pred_world_points_conf = pred_world_points_conf[selected_frame_idx][None]
images = images[selected_frame_idx][None]
last_pred_extrinsic = last_pred_extrinsic[selected_frame_idx][None]
vertices_3d = pred_world_points.reshape(-1, 3)
colors_rgb = np.transpose(images, (0, 2, 3, 1)) #images.permute(0, 3, 1, 2)
colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8)
camera_matrices = last_pred_extrinsic
conf = pred_world_points_conf.reshape(-1)
conf_mask = conf > conf_thres
if mask_black_bg:
black_bg_mask = colors_rgb.sum(axis=1) >= 16
conf_mask = conf_mask & black_bg_mask
vertices_3d = vertices_3d[conf_mask]
colors_rgb = colors_rgb[conf_mask]
# resize_scale = 0.25
# vertices_3d = vertices_3d * resize_scale
# camera_matrices[:, :3, 3] = camera_matrices[:, :3, 3] * resize_scale
if vertices_3d is None or np.asarray(vertices_3d).size == 0:
vertices_3d = np.array([[1, 0, 0]])
colors_rgb = np.array([[255, 255, 255]])
scene_scale = 1
else:
# Calculate the 5th and 95th percentiles along each axis
lower_percentile = np.percentile(vertices_3d, 5, axis=0)
upper_percentile = np.percentile(vertices_3d, 95, axis=0)
# Calculate the diagonal length of the percentile bounding box
scene_scale = np.linalg.norm(upper_percentile - lower_percentile)
colormap = matplotlib.colormaps.get_cmap("gist_rainbow")
# Initialize a 3D scene
scene_3d = trimesh.Scene()
# Add point cloud data to the scene
point_cloud_data = trimesh.PointCloud(
vertices=vertices_3d, colors=colors_rgb
)
scene_3d.add_geometry(point_cloud_data)
# Prepare 4x4 matrices for camera extrinsics
num_cameras = len(camera_matrices)
extrinsics_matrices = np.zeros((num_cameras, 4, 4))
extrinsics_matrices[:, :3, :4] = camera_matrices
extrinsics_matrices[:, 3, 3] = 1
if show_cam:
# Add camera models to the scene
for i in range(num_cameras):
world_to_camera = extrinsics_matrices[i]
camera_to_world = np.linalg.inv(world_to_camera)
rgba_color = colormap(i / num_cameras)
current_color = tuple(int(255 * x) for x in rgba_color[:3])
integrate_camera_into_scene(
scene_3d, camera_to_world, current_color, scene_scale
)
# Align scene to the observation of the first camera
scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices)
print("GLB Scene built")
return scene_3d
def integrate_camera_into_scene(
scene: trimesh.Scene,
transform: np.ndarray,
face_colors: tuple,
scene_scale: float,
):
"""
Integrates a fake camera mesh into the 3D scene.
Args:
scene (trimesh.Scene): The 3D scene to add the camera model.
transform (np.ndarray): Transformation matrix for camera positioning.
face_colors (tuple): Color of the camera face.
scene_scale (float): Scale of the scene.
"""
cam_width = scene_scale * 0.05
cam_height = scene_scale * 0.1
# Create cone shape for camera
rot_45_degree = np.eye(4)
rot_45_degree[:3, :3] = Rotation.from_euler(
"z", 45, degrees=True
).as_matrix()
rot_45_degree[2, 3] = -cam_height
opengl_transform = get_opengl_conversion_matrix()
# Combine transformations
complete_transform = transform @ opengl_transform @ rot_45_degree
camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4)
# Generate mesh for the camera
slight_rotation = np.eye(4)
slight_rotation[:3, :3] = Rotation.from_euler(
"z", 2, degrees=True
).as_matrix()
vertices_combined = np.concatenate(
[
camera_cone_shape.vertices,
0.95 * camera_cone_shape.vertices,
transform_points(slight_rotation, camera_cone_shape.vertices),
]
)
vertices_transformed = transform_points(
complete_transform, vertices_combined
)
mesh_faces = compute_camera_faces(camera_cone_shape)
# Add the camera mesh to the scene
camera_mesh = trimesh.Trimesh(
vertices=vertices_transformed, faces=mesh_faces
)
camera_mesh.visual.face_colors[:, :3] = face_colors
scene.add_geometry(camera_mesh)
def apply_scene_alignment(
scene_3d: trimesh.Scene, extrinsics_matrices: np.ndarray
) -> trimesh.Scene:
"""
Aligns the 3D scene based on the extrinsics of the first camera.
Args:
scene_3d (trimesh.Scene): The 3D scene to be aligned.
extrinsics_matrices (np.ndarray): Camera extrinsic matrices.
Returns:
trimesh.Scene: Aligned 3D scene.
"""
# Set transformations for scene alignment
opengl_conversion_matrix = get_opengl_conversion_matrix()
# Rotation matrix for alignment (180 degrees around the y-axis)
align_rotation = np.eye(4)
align_rotation[:3, :3] = Rotation.from_euler(
"y", 180, degrees=True
).as_matrix()
# Apply transformation
initial_transformation = (
np.linalg.inv(extrinsics_matrices[0])
@ opengl_conversion_matrix
@ align_rotation
)
scene_3d.apply_transform(initial_transformation)
return scene_3d
def get_opengl_conversion_matrix() -> np.ndarray:
"""
Constructs and returns the OpenGL conversion matrix.
Returns:
numpy.ndarray: A 4x4 OpenGL conversion matrix.
"""
# Create an identity matrix
matrix = np.identity(4)
# Flip the y and z axes
matrix[1, 1] = -1
matrix[2, 2] = -1
return matrix
def transform_points(
transformation: np.ndarray, points: np.ndarray, dim: int = None
) -> np.ndarray:
"""
Applies a 4x4 transformation to a set of points.
Args:
transformation (np.ndarray): Transformation matrix.
points (np.ndarray): Points to be transformed.
dim (int, optional): Dimension for reshaping the result.
Returns:
np.ndarray: Transformed points.
"""
points = np.asarray(points)
initial_shape = points.shape[:-1]
dim = dim or points.shape[-1]
# Apply transformation
transformation = transformation.swapaxes(
-1, -2
) # Transpose the transformation matrix
points = points @ transformation[..., :-1, :] + transformation[..., -1:, :]
# Reshape the result
result = points[..., :dim].reshape(*initial_shape, dim)
return result
def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray:
"""
Computes the faces for the camera mesh.
Args:
cone_shape (trimesh.Trimesh): The shape of the camera cone.
Returns:
np.ndarray: Array of faces for the camera mesh.
"""
# Create pseudo cameras
faces_list = []
num_vertices_cone = len(cone_shape.vertices)
for face in cone_shape.faces:
if 0 in face:
continue
v1, v2, v3 = face
v1_offset, v2_offset, v3_offset = face + num_vertices_cone
v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone
faces_list.extend(
[
(v1, v2, v2_offset),
(v1, v1_offset, v3),
(v3_offset, v2, v3),
(v1, v2, v2_offset_2),
(v1, v1_offset_2, v3),
(v3_offset_2, v2, v3),
]
)
faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list]
return np.array(faces_list)
def segment_sky(image_path, onnx_session, mask_filename=None):
assert mask_filename is not None
image = cv2.imread(image_path)
result_map = run_skyseg(onnx_session,[320,320],image)
# resize the result_map to the original image size
result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0]))
output_mask = np.zeros_like(result_map_original)
output_mask[result_map_original<1] = 1
output_mask = output_mask.astype(np.uint8) * 255
os.makedirs(os.path.dirname(mask_filename), exist_ok=True)
cv2.imwrite(mask_filename, output_mask)
return output_mask
def run_skyseg(onnx_session, input_size, image):
# Pre process:Resize, BGR->RGB, Transpose, PyTorch standardization, float32 cast
temp_image = copy.deepcopy(image)
resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1]))
x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB)
x = np.array(x, dtype=np.float32)
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
x = (x / 255 - mean) / std
x = x.transpose(2, 0, 1)
x = x.reshape(-1, 3, input_size[0], input_size[1]).astype('float32')
# Inference
input_name = onnx_session.get_inputs()[0].name
output_name = onnx_session.get_outputs()[0].name
onnx_result = onnx_session.run([output_name], {input_name: x})
# Post process
onnx_result = np.array(onnx_result).squeeze()
min_value = np.min(onnx_result)
max_value = np.max(onnx_result)
onnx_result = (onnx_result - min_value) / (max_value - min_value)
onnx_result *= 255
onnx_result = onnx_result.astype('uint8')
return onnx_result