PSHuman / utils /mesh_utils.py
fffiloni's picture
Migrated from GitHub
2252f3d verified
raw
history blame
7.32 kB
import torch
import numpy as np
import pymeshlab as ml
from pytorch3d.renderer import TexturesVertex
from pytorch3d.structures import Meshes
import torch
import torch.nn.functional as F
import trimesh
from pymeshlab import PercentageValue
import open3d as o3d
def tensor2variable(tensor, device):
    """Return a fresh leaf tensor on ``device`` with gradients enabled.

    Args:
        tensor: Source data, either a ``torch.Tensor`` or anything
            ``torch.tensor`` accepts (nested lists, NumPy arrays, ...).
            The data must be floating point (or complex), since only such
            tensors can require gradients.
        device: Target device of the result. ``None`` keeps the device of a
            tensor input (or the default device for other inputs).

    Returns:
        A copy of the data, detached from any autograd graph, with
        ``requires_grad=True``.
    """
    # NOTE(review): the original comment hinted at a [1,23,3,3] input shape;
    # nothing in this function depends on it.
    if isinstance(tensor, torch.Tensor):
        # torch.tensor(existing_tensor) emits a UserWarning; the documented
        # equivalent is an explicit detach + copy.
        detached = tensor.detach()
        if device is None:
            leaf = detached.clone()
        else:
            leaf = detached.to(device, copy=True)
        return leaf.requires_grad_(True)
    return torch.tensor(tensor, device=device, requires_grad=True)
def rot6d_to_rotmat(x):
    """Convert 6D rotation representation to 3x3 rotation matrix.

    Based on Zhou et al., "On the Continuity of Rotation Representations in
    Neural Networks", CVPR 2019.

    Args:
        x: Tensor whose elements can be regrouped as (B, 3, 2), e.g. a
            (B, 6) batch of 6-D rotation representations. Non-contiguous
            inputs are accepted.

    Returns:
        (B, 3, 3) batch of rotation matrices whose columns are the
        orthonormal basis vectors b1, b2, b3.
    """
    # reshape (not view) so that non-contiguous inputs work as well.
    x = x.reshape(-1, 3, 2)
    a1 = x[:, :, 0]
    a2 = x[:, :, 1]
    # Gram-Schmidt: normalise a1, remove its component from a2, normalise.
    b1 = F.normalize(a1, dim=-1)
    proj = torch.einsum("bi,bi->b", b1, a2).unsqueeze(-1)
    b2 = F.normalize(a2 - proj * b1, dim=-1)
    # dim must be explicit: without it torch.cross picks the FIRST dimension
    # of size 3, which is the batch axis whenever B == 3.
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-1)
def fix_vert_color_glb(mesh_path):
    """Attach a plain PBR material to the first primitive of a GLB, in place.

    The file at ``mesh_path`` is loaded, a new material (white base colour
    factor, metallic 0, roughness 1, no emission, double sided) is appended,
    the first primitive of the first mesh is pointed at it, and the file is
    written back to the same path. Presumably this makes viewers render the
    per-vertex colours of the mesh -- confirm against the target viewer.

    Args:
        mesh_path: Path of the GLB file to patch; it is overwritten.
    """
    from pygltflib import GLTF2, Material, PbrMetallicRoughness

    gltf = GLTF2().load(mesh_path)
    # Index the material we are about to append rather than assuming slot 0:
    # if the file already carries materials, slot 0 is one of those.
    material_index = len(gltf.materials)
    gltf.materials.append(Material(
        pbrMetallicRoughness=PbrMetallicRoughness(
            baseColorFactor=[1.0, 1.0, 1.0, 1.0],
            metallicFactor=0.,
            roughnessFactor=1.0,
        ),
        emissiveFactor=[0.0, 0.0, 0.0],
        doubleSided=True,
    ))
    gltf.meshes[0].primitives[0].material = material_index
    gltf.save(mesh_path)
def srgb_to_linear(c_srgb):
    """Convert sRGB-encoded values to linear RGB, clipped to [0, 1].

    Args:
        c_srgb: Array-like of sRGB values (NumPy array, list, scalar, ...).

    Returns:
        NumPy array of the same shape: ``c / 12.92`` where ``c <= 0.04045``,
        otherwise ``((c + 0.055) / 1.055) ** 2.4``, clipped to [0, 1].
    """
    c_srgb = np.asarray(c_srgb)
    linear_segment = c_srgb / 12.92
    # np.where evaluates both branches for every element. Clamping the base
    # at 0 keeps the power branch from producing NaN / RuntimeWarning for
    # negative inputs; those elements take the linear segment anyway.
    power_segment = ((np.maximum(c_srgb, 0.0) + 0.055) / 1.055) ** 2.4
    c_linear = np.where(c_srgb <= 0.04045, linear_segment, power_segment)
    return c_linear.clip(0, 1.)
def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to_LinearRGB=True):
    """Export a vertex-coloured PyTorch3D mesh to disk through trimesh.

    Args:
        meshes: PyTorch3D ``Meshes`` whose textures expose
            ``verts_features_packed()`` with three colour channels in [0, 1].
        save_glb_path: Output path. When it ends with ``.glb`` the X and Z
            coordinates are negated before export and the written file is
            patched by ``fix_vert_color_glb``.
        apply_sRGB_to_LinearRGB: Convert the colours with ``srgb_to_linear``
            before export.

    Raises:
        AssertionError: If vertex and colour counts differ, the colours do
            not have three channels, or a colour lies outside [0, 1].
    """
    is_glb = save_glb_path.endswith(".glb")

    # convert from pytorch3d meshes to trimesh mesh
    # detach(): .numpy() raises on tensors that still require grad.
    vertices = meshes.verts_packed().detach().cpu().float().numpy()
    triangles = meshes.faces_packed().detach().cpu().long().numpy()
    np_color = meshes.textures.verts_features_packed().detach().cpu().float().numpy()

    if is_glb:
        # rotate 180 along +Y
        # For CPU float32 meshes the array above shares memory with the
        # mesh's own tensor, so flip a copy to leave the caller's mesh intact.
        vertices = vertices.copy()
        vertices[:, [0, 2]] = -vertices[:, [0, 2]]

    if apply_sRGB_to_LinearRGB:
        np_color = srgb_to_linear(np_color)

    assert vertices.shape[0] == np_color.shape[0]
    assert np_color.shape[1] == 3
    assert 0 <= np_color.min() and np_color.max() <= 1, f"min={np_color.min()}, max={np_color.max()}"

    mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color)
    mesh.remove_unreferenced_vertices()
    # save mesh
    mesh.export(save_glb_path)
    if is_glb:
        fix_vert_color_glb(save_glb_path)
    print(f"saving to {save_glb_path}")
def load_mesh_with_trimesh(file_name, file_type=None):
    """Load a mesh file with trimesh and return it as torch tensors.

    A loaded ``trimesh.Scene`` is exported to an in-memory OBJ and re-read;
    if the result is still a scene, its geometries are concatenated using
    vertices and faces only.

    Args:
        file_name: Passed straight to ``trimesh.load``.
        file_type: Optional format hint forwarded to ``trimesh.load``.

    Returns:
        Tuple ``(vertices, faces, colors)``. Each tensor is the transpose of
        the corresponding trimesh array. ``colors`` holds the first three
        vertex-colour channels divided by 255, or a constant 0.5 tensor
        shaped like ``vertices`` when the visual has no ``vertex_colors``.

    Raises:
        AssertionError: If a loaded scene has no geometry, or the final
            object is not a ``trimesh.Trimesh``.
    """
    loaded = trimesh.load(file_name, file_type=file_type)
    if isinstance(loaded, trimesh.Scene):
        assert len(loaded.geometry) > 0
        # Round-trip through OBJ in memory to avoid the offset issue.
        from io import BytesIO
        with BytesIO() as buffer:
            loaded.export(buffer, file_type="obj")
            buffer.seek(0)
            loaded = trimesh.load(buffer, file_type="obj")
        if isinstance(loaded, trimesh.Scene):
            # Geometry-only merge: texture information is lost here.
            parts = tuple(
                trimesh.Trimesh(vertices=geom.vertices, faces=geom.faces)
                for geom in loaded.geometry.values()
            )
            loaded = trimesh.util.concatenate(parts)
    assert isinstance(loaded, trimesh.Trimesh)

    vertices = torch.from_numpy(loaded.vertices).T
    faces = torch.from_numpy(loaded.faces).T

    visual = loaded.visual
    if visual is not None and hasattr(visual, 'vertex_colors'):
        # Drop the alpha channel and rescale 0..255 to 0..1.
        colors = torch.from_numpy(visual.vertex_colors)[..., :3].T / 255.
    else:
        # No vertex colours available: fill with mid gray.
        colors = torch.ones_like(vertices) * 0.5
    return vertices, faces, colors
def meshlab_mesh_to_py3dmesh(mesh: ml.Mesh) -> Meshes:
    """Wrap a pymeshlab mesh as a single-element PyTorch3D ``Meshes``.

    Vertices become float tensors, faces long tensors, and the first three
    columns of the vertex colour matrix (presumably RGB, with alpha dropped)
    become ``TexturesVertex`` features.
    """
    vert_tensor = torch.from_numpy(mesh.vertex_matrix()).float()
    face_tensor = torch.from_numpy(mesh.face_matrix()).long()
    rgb = mesh.vertex_color_matrix()[..., :3]
    color_tensor = torch.from_numpy(rgb).float()
    return Meshes(
        verts=[vert_tensor],
        faces=[face_tensor],
        textures=TexturesVertex(verts_features=[color_tensor]),
    )
def py3dmesh_to_meshlab_mesh(meshes: Meshes) -> ml.Mesh:
    """Convert a vertex-coloured PyTorch3D ``Meshes`` into a pymeshlab mesh.

    Packed vertices, faces, vertex normals and vertex colours are copied to
    NumPy. The colours get a fourth channel filled with 1 (presumably alpha)
    before being handed to ``ml.Mesh``.

    Args:
        meshes: PyTorch3D mesh with ``TexturesVertex``-style textures. Its
            tensors may live on any device and may require grad.

    Returns:
        ``ml.Mesh`` with float64 vertices / normals / colours and int32 faces.
    """
    # detach(): .numpy() raises on tensors that are part of an autograd graph,
    # which is common for meshes that are being optimised.
    rgb = meshes.textures.verts_features_packed().detach().cpu().float()
    colors_in = F.pad(rgb, [0, 1], value=1).numpy().astype(np.float64)
    verts = meshes.verts_packed().detach().cpu().float().numpy().astype(np.float64)
    faces = meshes.faces_packed().detach().cpu().long().numpy().astype(np.int32)
    normals = meshes.verts_normals_packed().detach().cpu().float().numpy().astype(np.float64)
    m1 = ml.Mesh(
        vertex_matrix=verts,
        face_matrix=faces,
        v_normals_matrix=normals,
        v_color_matrix=colors_in)
    return m1
def to_pyml_mesh(vertices, faces):
    """Build a pymeshlab mesh from vertex and face tensors.

    Args:
        vertices: Torch tensor of vertex positions; converted to float64.
        faces: Torch tensor of face indices; converted to int32.

    Returns:
        ``ml.Mesh`` holding copies of the given geometry.
    """
    # detach(): .numpy() raises on tensors that require grad, e.g. vertices
    # that are currently being optimised.
    vertex_matrix = vertices.detach().cpu().float().numpy().astype(np.float64)
    face_matrix = faces.detach().cpu().long().numpy().astype(np.int32)
    m1 = ml.Mesh(
        vertex_matrix=vertex_matrix,
        face_matrix=face_matrix,
    )
    return m1
def to_py3d_mesh(vertices, faces, normals=None):
    """Create a single-mesh PyTorch3D ``Meshes`` coloured by its normals.

    Args:
        vertices: Vertex tensor for the mesh.
        faces: Face index tensor for the mesh.
        normals: Optional per-vertex normals. When omitted, the normals
            PyTorch3D computes for the mesh are used.

    Returns:
        ``Meshes`` whose vertex textures are ``normals / 2 + 0.5``
        (presumably mapping unit normals into the [0, 1] colour range).
    """
    from pytorch3d.structures import Meshes
    from pytorch3d.renderer.mesh.textures import TexturesVertex

    result = Meshes(verts=[vertices], faces=[faces], textures=None)
    vert_normals = result.verts_normals_packed() if normals is None else normals
    # Use the normals as vertex colours.
    normal_colors = vert_normals / 2 + 0.5
    result.textures = TexturesVertex(verts_features=[normal_colors])
    return result
def from_py3d_mesh(mesh):
    """Unpack a PyTorch3D mesh into ``(vertices, faces, vertex_features)``.

    Vertices and faces are those of the first mesh in the batch; the vertex
    features are the packed features of the whole batch, so the three only
    line up for single-mesh batches (presumably the intended use).
    """
    first_verts = mesh.verts_list()[0]
    first_faces = mesh.faces_list()[0]
    vert_features = mesh.textures.verts_features_packed()
    return first_verts, first_faces, vert_features
def simple_clean_mesh(pyml_mesh: ml.Mesh, apply_smooth=True, stepsmoothnum=1, apply_sub_divide=False, sub_divide_threshold=0.25):
    """Optionally smooth and subdivide a pymeshlab mesh.

    Args:
        pyml_mesh: Input pymeshlab mesh.
        apply_smooth: Run Laplacian coordinate smoothing (uniform weights).
        stepsmoothnum: Number of smoothing steps.
        apply_sub_divide: Repair non-manifold vertices and edges, then run
            two iterations of Loop subdivision. Noted as slow (about 5 s) in
            the original code.
        sub_divide_threshold: Wrapped in ``PercentageValue`` and used as the
            subdivision threshold.

    Returns:
        The processed mesh converted by ``meshlab_mesh_to_py3dmesh``.
    """
    mesh_set = ml.MeshSet()
    mesh_set.add_mesh(pyml_mesh, "cube_mesh")
    run_filter = mesh_set.apply_filter

    if apply_smooth:
        run_filter(
            "apply_coord_laplacian_smoothing",
            stepsmoothnum=stepsmoothnum,
            cotangentweight=False,
        )

    if apply_sub_divide:  # 5s, slow
        # Subdivision is preceded by the two non-manifold repair filters.
        run_filter("meshing_repair_non_manifold_vertices")
        run_filter("meshing_repair_non_manifold_edges", method='Remove Faces')
        run_filter(
            "meshing_surface_subdivision_loop",
            iterations=2,
            threshold=PercentageValue(sub_divide_threshold),
        )

    cleaned = mesh_set.current_mesh()
    return meshlab_mesh_to_py3dmesh(cleaned)
def post_process_mesh(mesh, cluster_to_keep=1000):
    """Post-process a mesh to filter out floaters and disconnected parts.

    Connected triangle clusters are ranked by triangle count. Clusters
    smaller than the ``cluster_to_keep``-th largest one, or smaller than 50
    triangles, are removed from a deep copy of the mesh. Clusters that tie
    with the threshold are kept, so more than ``cluster_to_keep`` clusters
    can survive.

    Args:
        mesh: Open3D triangle mesh (needs ``cluster_connected_triangles``);
            it is not modified.
        cluster_to_keep: Rank of the smallest cluster to keep. Values larger
            than the number of clusters keep every cluster of at least 50
            triangles; values below 1 are treated as 1.

    Returns:
        The cleaned copy of ``mesh``.
    """
    import copy

    print("post processing the mesh to keep the {} largest clusters".format(cluster_to_keep))
    mesh_0 = copy.deepcopy(mesh)
    with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug):
        triangle_clusters, cluster_n_triangles, _ = mesh_0.cluster_connected_triangles()

    # Explicit integer dtype so the fancy-indexing below also works when the
    # mesh has no triangles at all.
    triangle_clusters = np.asarray(triangle_clusters, dtype=np.int64)
    cluster_n_triangles = np.asarray(cluster_n_triangles)

    min_triangles = 50  # filter meshes smaller than 50
    if cluster_n_triangles.size > 0:
        # Clamp the rank: indexing with -cluster_to_keep raised IndexError
        # whenever the mesh had fewer clusters than requested.
        rank = min(max(int(cluster_to_keep), 1), cluster_n_triangles.size)
        min_triangles = max(int(np.sort(cluster_n_triangles)[-rank]), min_triangles)

    triangles_to_remove = cluster_n_triangles[triangle_clusters] < min_triangles
    mesh_0.remove_triangles_by_mask(triangles_to_remove)
    mesh_0.remove_unreferenced_vertices()
    mesh_0.remove_degenerate_triangles()
    print("num vertices raw {}".format(len(mesh.vertices)))
    print("num vertices post {}".format(len(mesh_0.vertices)))
    return mesh_0