# yslan's picture
# init
# 7f51798
# raw
# history blame
# 16.7 kB
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
import json
import argparse
import numpy as np
import torch
import os
import random
import glob
from tqdm import tqdm
import kaolin as kal
import point_cloud_utils as pcu
import ipdb
import pandas as pd
import numpy as np
from functools import partial
from pdb import set_trace as st
from pathlib import Path
from functools import partial
# unused, already matched
# varyData = [
# ["X", 270],
# ["Z", 180],
# ]
def rotate_point_cloud(point_cloud, rotations):
    """
    Apply a sequence of axis rotations to a point array, in the order given.

    Each step multiplies the (row-vector) points by the transpose of
    ``rotation_matrix(axis, radians(degrees))``, so the composite rotation is
    the left-to-right composition of the listed steps.
    :param point_cloud: Nx3 numpy array of points (not modified; a copy is rotated)
    :param rotations: list of ``(axis, angle_in_degrees)`` tuples, axis in 'x'/'y'/'z'
        Example: [('x', 90), ('y', 45)]
    :return: rotated Nx3 numpy array
    """
    rotated = point_cloud.copy()
    for axis, degrees in rotations:
        step = rotation_matrix(axis, np.radians(degrees))
        rotated = np.dot(rotated, step.T)
    return rotated
# Per-method rotations (axis, degrees), applied in order, that align every method's
# outputs in the same canonical space. Each value is ``rotate_point_cloud`` with the
# rotations bound, i.e. a callable that takes an Nx3 numpy array.
transformation_dict = {
    method: partial(rotate_point_cloud, rotations=steps)
    for method, steps in {
        'gso': [('x', 0)],  # no transformation
        'LGM_fixpose': [('x', 90), ('z', 180)],
        'CRM/Animals': [('x', 90), ('z', 180)],
        'Lara': [('x', -110), ('z', 33)],
        'ln3diff-lite/Animals': [('x', 90)],
        'One-2-3-45/Animals': [('x', 90), ('z', 180)],
        'splatter-img': [('x', -60)],
        # identity rotation
        'OpenLRM/Animals': [('x', 0)],
        'shape-e/Animals': [('x', 0)],
        # identity rotation
        'objv-gt': [('x', 0)],
        'GA': [('x', 0)],
        # un-aligned
        'scale3d/eval/eval_nerf/Animals': [('x', 0)],
        'scale3d/eval/eval_mesh/Animals': [('x', 180), ('z', 180)],
    }.items()
}
def VaryPoint(data, axis, degree):
    """Right-multiply row-vector points by the standard rotation matrix about ``axis``.

    Computes ``data @ R`` (not ``data @ R.T``), so for row vectors this rotates
    by ``-degree`` about the axis, the opposite sense of ``rotate_point_cloud``.

    :param data: array of row vectors with 3 columns
    :param axis: 'X', 'Y' or 'Z' (upper case); any other key raises KeyError
    :param degree: rotation angle in degrees
    :return: rotated points as a numpy array
    """
    theta = radians(degree)
    c, s = cos(theta), sin(theta)
    matrices = {
        'X': [[1, 0, 0],
              [0, c, -s],
              [0, s, c]],
        'Y': [[c, 0, s],
              [0, 1, 0],
              [-s, 0, c]],
        'Z': [[c, -s, 0],
              [s, c, 0],
              [0, 0, 1]],
    }
    return np.dot(data, np.array(matrices[axis]))
from math import *
def seed_everything(seed):
    """Seed torch, numpy and Python's ``random`` with the same value.

    A negative ``seed`` is treated as "do not seed" and leaves all RNGs untouched.
    """
    if seed >= 0:
        for seeder in (torch.manual_seed, np.random.seed, random.seed):
            seeder(seed)
def read_pcd(name, n_sample=2048):
    """Load the vertices stored in ``name`` and keep a random subset of them.

    :param name: path readable by ``pcu.load_mesh_v``
    :param n_sample: maximum number of vertex rows to keep
    :return: tensor with a leading batch dimension of 1 holding up to ``n_sample``
        randomly ordered vertex rows (dtype follows what pcu returns)
    """
    vertices = pcu.load_mesh_v(name)
    kept = np.random.permutation(vertices)[:n_sample, :]
    batched = torch.from_numpy(kept)[None, ...]
    return batched
def get_score(results, use_same_numer_for_test=False):
    """Compute MMD and coverage from a reference-by-sample distance matrix.

    Rows index reference shapes and columns index generated shapes (the layout
    built by ``chamfer_distance``).

    :param results: 2-D numpy array of pairwise distances
    :param use_same_numer_for_test: if True, keep only the first ``n_ref`` columns
        so both sets have the same size
    :return: ``(mmd, cov)`` where mmd is the mean over references of the distance
        to the nearest sample, multiplied by 1000 for display, and cov is the
        percentage of references that are the nearest neighbour of at least one sample
    """
    dist = results[:, :results.shape[0]] if use_same_numer_for_test else results
    n_ref = dist.shape[0]

    # MMD: for every reference, distance to its closest generated shape.
    closest_per_ref = dist.min(axis=1)
    mmd = closest_per_ref.mean() * 1000  # scaled for showing results

    # COV: how many distinct references are "claimed" by some sample.
    claimed_refs = np.unique(dist.argmin(axis=0))
    coverage = float(len(claimed_refs)) / n_ref
    return mmd, coverage * 100
def rotation_matrix(axis, angle):
    """
    Build the standard right-handed 3x3 rotation matrix about a principal axis.

    :param axis: str, one of 'x', 'y', 'z' (lower case)
    :param angle: float, rotation angle in radians
    :return: 3x3 numpy array R such that ``R @ v`` rotates column vector v by ``angle``
    :raises ValueError: if ``axis`` is not 'x', 'y' or 'z'
    """
    if axis not in ('x', 'y', 'z'):
        raise ValueError("Axis must be 'x', 'y', or 'z'.")
    c, s = np.cos(angle), np.sin(angle)
    layouts = {
        'x': [[1, 0, 0],
              [0, c, -s],
              [0, s, c]],
        'y': [[c, 0, s],
              [0, 1, 0],
              [-s, 0, c]],
        'z': [[c, -s, 0],
              [s, c, 0],
              [0, 0, 1]],
    }
    return np.array(layouts[axis])
def scale_to_unit_sphere(points, center=None):
    """Center each point cloud on its bounding-box midpoint and scale it into the unit ball.

    :param points: tensor of shape (B, N, D); each of the B clouds is normalized
        independently.
    :param center: ignored; kept for call compatibility (the bounding-box midpoint
        is always used as the center).
    :return: tensor of the same shape in which the farthest point of every cloud lies
        at distance 1 from that cloud's bounding-box midpoint. A degenerate cloud
        (all points identical) maps to all zeros instead of NaN.
    """
    # Per-cloud bounding-box midpoint, kept as (B, 1, D) so it broadcasts over the
    # N points. A (B, D) midpoint only broadcasts correctly when B == 1.
    midpoints = (points.amax(dim=1, keepdim=True) + points.amin(dim=1, keepdim=True)) / 2
    centered = points - midpoints
    # Per-cloud radius (B, 1, 1). A single global max would let one cloud's
    # extent rescale every other cloud in the batch.
    radius = centered.pow(2).sum(dim=2, keepdim=True).sqrt().amax(dim=1, keepdim=True)
    # Avoid 0/0 -> NaN for clouds collapsed to a single point.
    radius = radius.clamp_min(torch.finfo(radius.dtype).tiny)
    return centered / radius
def _sample_points_from_obj(name, n_sample):
    """Sample ``n_sample`` surface points from the OBJ mesh at ``name`` with kaolin.

    :return: CUDA tensor with a leading batch dimension of 1, normalized by
        scale_to_unit_sphere, or None when the mesh has no vertices.
    """
    mesh = kal.io.obj.import_mesh(name)
    if mesh.vertices.shape[0] == 0:
        return None
    vertices = mesh.vertices.cuda()
    faces = mesh.faces.cuda()
    points, _ = kal.ops.mesh.sample_points(vertices.unsqueeze(dim=0), faces, n_sample)
    return scale_to_unit_sphere(points).cuda()


def _sample_points_from_vertices(method_name, name, n_sample):
    """Randomly subsample the vertices in ``name`` and rotate them into the canonical frame.

    :return: float32 CUDA tensor of shape (1, k, 3) with k <= n_sample, normalized by
        scale_to_unit_sphere and rotated by ``transformation_dict[method_name]``.
    """
    vertices = pcu.load_mesh_v(name)
    subset = np.random.permutation(vertices)[:n_sample, :]
    points = torch.from_numpy(subset).float().cuda().unsqueeze(dim=0)
    points = scale_to_unit_sphere(points)
    # Each method has its own output orientation; map it to the shared canonical space.
    to_canonical = transformation_dict[method_name]
    rotated = to_canonical(points[0].cpu().numpy())
    return torch.from_numpy(rotated).float().cuda().unsqueeze(dim=0)


def sample_point_with_mesh_name(method_name, name, n_sample=2048, normalized_scale=1.0, rotate_degree=-90):
    """Load shape ``name`` and return its points normalized to the unit sphere.

    This first tries kaolin's OBJ importer and samples points on the mesh surface.
    If anything on that path raises, it falls back to randomly subsampling the file's
    vertices with point_cloud_utils. Presumably this happens for the ``.ply``
    inputs used by evaluate(). The fallback path then rotates the points into the
    canonical frame for ``method_name``. The canonical rotation is applied only on
    the fallback path.

    :param method_name: key into ``transformation_dict`` (used on the fallback path only)
    :param name: path of the mesh / point-cloud file
    :param n_sample: number of points to sample
    :param normalized_scale: unused; kept for call compatibility
    :param rotate_degree: unused; kept for call compatibility
    :return: CUDA tensor with a leading batch dimension of 1, or None if the OBJ
        mesh has no vertices
    """
    try:
        return _sample_points_from_obj(name, n_sample)
    except Exception:
        # Catch Exception rather than using a bare ``except:`` so Ctrl-C and
        # SystemExit still propagate.
        return _sample_points_from_vertices(method_name, name, n_sample)
def _sample_and_cache(method_name, paths, n_sample, normalized_scale, save_dir, filename):
    """Sample every file in ``paths``, drop failures (None), stack on CUDA, and save to ``save_dir/filename``."""
    clouds = [
        sample_point_with_mesh_name(method_name, path, n_sample,
                                    normalized_scale=normalized_scale, rotate_degree=0)
        for path in tqdm(paths)
    ]
    clouds = torch.cat([c for c in clouds if c is not None], dim=0).to('cuda')
    os.makedirs(save_dir, exist_ok=True)
    torch.save(clouds, os.path.join(save_dir, filename))
    return clouds


def _pairwise_chamfer(ref_clouds, sample_clouds, batch_size):
    """Return the (n_ref, n_sample) matrix of kaolin Chamfer distances, batching over samples."""
    n_samples = sample_clouds.size(0)
    rows = []
    for ref_p in tqdm(ref_clouds):
        row = []
        for start in range(0, n_samples, batch_size):
            batch = sample_clouds[start:start + batch_size]
            expanded_ref = ref_p.unsqueeze(dim=0).expand(batch.size(0), -1, -1)
            row.append(kal.metrics.pointcloud.chamfer_distance(expanded_ref, batch))
        rows.append(torch.cat(row, dim=0).unsqueeze(dim=0))
    return torch.cat(rows, dim=0)


def chamfer_distance(method_name, ref_name, ref_pcs, sample_pcs, batch_size, save_name):
    """Compute the pairwise Chamfer-distance matrix between reference and generated shapes.

    Reference point clouds are cached in ``<save_name>/gt.pth``. If that file exists
    it is reused as-is, even when ``ref_pcs`` has changed, so delete it to refresh.
    Generated point clouds are always resampled and written to
    ``<save_name>/sample.pth``.

    :param method_name: ``transformation_dict`` key for the generated shapes
    :param ref_name: ``transformation_dict`` key for the reference shapes
    :param ref_pcs: list of reference file paths
    :param sample_pcs: list of generated file paths
    :param batch_size: number of generated shapes compared against one reference at a time
    :param save_name: directory for the cached point clouds (must not be None)
    :return: tensor of shape (n_ref, n_sample); files that failed to load (None)
        are dropped from both axes
    """
    n_sample = 2048
    normalized_scale = 1.0
    gt_cache = os.path.join(save_name, 'gt.pth')
    if os.path.exists(gt_cache):
        all_rec_pcs = torch.load(gt_cache).to('cuda')
    else:
        all_rec_pcs = _sample_and_cache(ref_name, ref_pcs, n_sample, normalized_scale,
                                        save_name, 'gt.pth')
    all_sample_pcs = _sample_and_cache(method_name, sample_pcs, n_sample, normalized_scale,
                                       save_name, 'sample.pth')
    print('datapreparation')
    # Batch over the rows actually stacked. Dropped (None) files make len(sample_pcs)
    # larger than the tensor, which previously could yield empty trailing batches.
    return _pairwise_chamfer(all_rec_pcs, all_sample_pcs, batch_size)
def compute_all_metrics(method_name, ref_name, sample_pcs, ref_pcs, batch_size, save_name=None):
    """Compute, print and return Chamfer-based MMD and coverage for generated shapes.

    :param method_name: ``transformation_dict`` key for the generated shapes
    :param ref_name: ``transformation_dict`` key for the reference shapes
    :param sample_pcs: list of generated file paths
    :param ref_pcs: list of reference file paths
    :param batch_size: Chamfer batch size
    :param save_name: cache directory passed to ``chamfer_distance``; in practice
        required, because None fails inside ``os.path.join`` there
    :return: ``(cd_mmd, cd_cov)`` as produced by ``get_score``
    """
    # Note the argument order swap: chamfer_distance takes references first.
    results = chamfer_distance(method_name, ref_name, ref_pcs, sample_pcs, batch_size, save_name).data.cpu().numpy()
    cd_mmd, cd_cov = get_score(results, use_same_numer_for_test=False)
    print('cov,mmd:', (cd_cov, cd_mmd, save_name))
    return cd_mmd, cd_cov
def evaluate(args):
    """Evaluate generated shapes in ``args.gen_path`` against Objaverse 'Animals' references.

    The reference list takes every 3rd 'Animals' entry of the dataset index and then
    keeps positions 1100..1699 of that strided list (at most 600 shapes). The
    ``<obj>_pcd_4096.ply`` files for those entries are read from
    ``args.dataset_path``, where '/' in an object id becomes '-'.
    Both ``args.gen_path`` and ``args.dataset_path`` must lie under the hard-coded
    ``gen_path_base``. Their relative paths are used as ``transformation_dict`` keys,
    and ``Path.relative_to`` raises ValueError otherwise.

    :return: whatever ``compute_all_metrics`` returns
    """
    # seed_everything(0) was the GA default.
    seed_everything(42)

    gen_path_base = '/mnt/sfs-common/yslan/Repo/3dgen/FID-KID-Outputdir-objv/3D-metrics-fps'
    objv_dataset = '/mnt/sfs-common/yslan/Dataset/Obajverse/chunk-jpeg-normal/bs_16_fixsave3/170K/512/'
    index_path = os.path.join(objv_dataset, 'dataset.json')
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(index_path, 'r', encoding='utf-8') as f:
        dataset_index = json.load(f)
    all_objs = dataset_index['Animals'][::3][1100:2200][:600]  # pick top 600 instances.
    ref_path = [os.path.join(args.dataset_path, f"{obj.replace('/', '-')}_pcd_4096.ply") for obj in all_objs]

    gen_path = args.gen_path
    method_name = str(Path(gen_path).relative_to(gen_path_base))
    ref_name = str(Path(args.dataset_path).relative_to(gen_path_base))

    gen_models = sorted(glob.glob(os.path.join(gen_path, '*.ply')))[:args.n_shape]
    with torch.no_grad():
        return compute_all_metrics(method_name, ref_name, gen_models, ref_path, args.batch_size, args.save_name)
if __name__ == "__main__":
    def _parse_bool(value):
        """Convert a command-line string to bool; ``type=bool`` would map 'False' to True."""
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--save_name", type=str, default='/mnt/petrelfs/caoziang/3D_generation/cmetric/get3d/omni_final_surface', help="path to the save results")
    parser.add_argument("--dataset_path", type=str, default='/mnt/petrelfs/share_data/wutong/DATA/OO3D/ply_files/4096', help="path to the original shapenet dataset")
    parser.add_argument("--gen_path", type=str, default='/mnt/petrelfs/caoziang/3D_generation/Checkpoint_all/diffusion_shapenet_testmodel11/ddpm_5/test', help="path to the generated models")
    parser.add_argument("--n_points", type=int, default=2048, help="Number of points used for evaluation")
    parser.add_argument("--batch_size", type=int, default=100, help="batch size to compute chamfer distance")
    parser.add_argument("--n_shape", type=int, default=7500, help="number of shapes for evaluations")
    parser.add_argument("--use_npz", type=_parse_bool, default=False, help="whether the generated shape is npz or not")
    args = parser.parse_args()
    evaluate(args)