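# Web file-viewer header captured along with the source (not part of the program):
# uploader avatar "yslan", commit message "init", commit 7f51798, file size 13.1 kB.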
# https://raw.githubusercontent.com/3dlg-hcvc/omages/refs/heads/main/src/evals/fpd_eval.py
import os
import random
from tqdm import tqdm
import glob
from pdb import set_trace as st
import trimesh
import sys
import numpy as np
import scipy # should be version 1.11.1
import torch
import argparse
# from point_e.evals.feature_extractor import PointNetClassifier, get_torch_devices
from feature_extractor import PointNetClassifier, get_torch_devices
from point_e.evals.fid_is import compute_statistics
from point_e.evals.fid_is import compute_inception_score
from point_e.evals.npz_stream import NpzStreamer
import numpy as np
def rotation_matrix(axis, angle):
    """
    Build the 3x3 matrix that rotates about a single coordinate axis.

    :param axis: str, the axis to rotate around ('x', 'y', or 'z')
    :param angle: float, the rotation angle in radians
    :return: 3x3 rotation matrix as a numpy array
    :raises ValueError: if ``axis`` is not 'x', 'y' or 'z'
    """
    # Reject unknown axes before doing any trigonometry.
    if axis != 'x' and axis != 'y' and axis != 'z':
        raise ValueError("Axis must be 'x', 'y', or 'z'.")

    c = np.cos(angle)
    s = np.sin(angle)

    if axis == 'x':
        rows = [[1, 0, 0],
                [0, c, -s],
                [0, s, c]]
    elif axis == 'y':
        rows = [[c, 0, s],
                [0, 1, 0],
                [-s, 0, c]]
    else:  # axis == 'z'
        rows = [[c, -s, 0],
                [s, c, 0],
                [0, 0, 1]]
    return np.array(rows)
def rotate_point_cloud(point_cloud, rotations):
    """
    Apply a sequence of axis rotations to a point cloud.

    The input is never modified: the function works on a copy, and each
    rotation produces a new array.

    :param point_cloud: Nx3 numpy array of points
    :param rotations: list of tuples [(axis, angle_in_degrees), ...], applied
                      in list order.
                      Example: [('x', 90), ('y', 45)] for composite rotations
    :return: Rotated point cloud as Nx3 numpy array
    """
    result = point_cloud.copy()
    for axis_name, degrees in rotations:
        matrix = rotation_matrix(axis_name, np.radians(degrees))
        # Points are stored as rows, hence the right-multiplication by R^T.
        result = np.dot(result, matrix.T)
    return result
from functools import partial
# Rotation sequence per method / directory name, as (axis, angle_in_degrees)
# steps applied in list order by rotate_point_cloud.
_ROTATIONS_BY_METHOD = {
    'gso': [('x', 0)],  # no transformation
    'LGM': [('x', 90)],
    'CRM': [('x', 90), ('z', 180)],
    'Lara': [('x', -110), ('z', 33)],
    'ln3diff': [('x', 90)],
    'One-2-3-45': [('x', 90), ('z', 180)],
    'splatter-img': [('x', -60)],
    #
    'OpenLRM': [('x', 0)],
    'shape-e': [('x', 0)],
    # un-aligned
    'ditl-fromditlPCD-fixPose-tomesh': [('x', 0)],
    'ditl-fromditlPCD-fixPose-tomesh-ditxlPCD': [('x', 0)],
}

# transformation dictionary: method name -> callable taking a point cloud and
# returning it rotated by that method's fixed rotation sequence.
transformation_dict = {
    method: partial(rotate_point_cloud, rotations=steps)
    for method, steps in _ROTATIONS_BY_METHOD.items()
}
class PFID_evaluator():
    """
    Computes P-FID and P-KID between two batches of point clouds using
    features extracted by a PointNetClassifier.
    """

    def __init__(self, devices=('cuda:0',), batch_size=256, cache_dir='~/.temp/PFID_evaluator'):
        """
        :param devices: iterable of device specifiers accepted by torch.device
        :param batch_size: forwarded to PointNetClassifier as device_batch_size
        :param cache_dir: directory for the classifier cache and the temporary
                          npz files written by compute_pfid; '~' is expanded
        """
        # Store the *expanded* path: compute_pfid writes files under
        # self.cache_dir, and open()/np.savez do not expand '~' themselves.
        cache_dir = os.path.expanduser(cache_dir)
        os.makedirs(cache_dir, exist_ok=True)

        self.batch_size = batch_size
        self.cache_dir = cache_dir
        self.devices = [torch.device(d) for d in devices]
        self.clf = PointNetClassifier(devices=self.devices, cache_dir=cache_dir, device_batch_size=self.batch_size)

    def compute_pfid(self, pc_1, pc_2, return_feature=False):
        """
        Extract classifier features for two point-cloud batches and compare them.

        :param pc_1: first batch of point clouds (array-like accepted by np.savez;
                     presumably (num_clouds, num_points, 3) -- TODO confirm against
                     PointNetClassifier)
        :param pc_2: second batch of point clouds, same layout as pc_1
        :param return_feature: if True, return the raw features instead of scores
        :return: (features_1, features_2) when return_feature is True, otherwise
                 dict(PFID=..., PKID=...)
        """
        # The classifier is fed through NpzStreamer, so the clouds are first
        # written to npz files in the cache directory.
        # NOTE: the file names are fixed, so two evaluators sharing one
        # cache_dir would overwrite each other's temporary files.
        npz_path1 = os.path.join(self.cache_dir, "temp1.npz")
        npz_path2 = os.path.join(self.cache_dir, "temp2.npz")
        np.savez(npz_path1, arr_0=pc_1)
        np.savez(npz_path2, arr_0=pc_2)

        features_1, _ = self.clf.features_and_preds(NpzStreamer(npz_path1))
        features_2, _ = self.clf.features_and_preds(NpzStreamer(npz_path2))

        if return_feature:
            # Statistics are not needed on this path, so skip computing them.
            return features_1, features_2

        stats_1 = compute_statistics(features_1)
        stats_2 = compute_statistics(features_2)

        # PFID = stats_1.frechet_distance(stats_2)  # same result as the next line
        PFID = frechet_distance(stats_1.mu, stats_1.sigma, stats_2.mu, stats_2.sigma)
        PKID = kernel_distance(features_1, features_2)

        print(f"P-FID: {PFID}", f"P-KID: {PKID}")
        return dict(PFID=PFID, PKID=PKID)
# from https://github.com/GaParmar/clean-fid/blob/main/cleanfid/fid.py
def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """
    Numpy implementation of the Frechet Distance.

    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Danica J. Sutherland.

    Params:
        mu1   : Mean vector of the first set of activations.
        sigma1: Covariance matrix of the first set of activations.
        mu2   : Mean vector of the second set of activations.
        sigma2: Covariance matrix of the second set of activations.
        eps   : Value added to the diagonal of both covariances when the
                matrix square root of their product is not finite.
    Returns:
        The squared Frechet distance d^2.
    Raises:
        ValueError: if the matrix square root has a diagonal imaginary part
                    that is not close to zero (atol=1e-3).
    """
    mean_a = np.atleast_1d(mu1)
    mean_b = np.atleast_1d(mu2)
    cov_a = np.atleast_2d(sigma1)
    cov_b = np.atleast_2d(sigma2)

    assert mean_a.shape == mean_b.shape, \
        'Training and test mean vectors have different lengths'
    assert cov_a.shape == cov_b.shape, \
        'Training and test covariances have different dimensions'

    mean_delta = mean_a - mean_b

    # Product might be almost singular
    sqrt_prod, _ = scipy.linalg.sqrtm(np.dot(cov_a, cov_b), disp=False)
    if not np.all(np.isfinite(sqrt_prod)):
        print(('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps)
        jitter = np.eye(cov_a.shape[0]) * eps
        sqrt_prod = scipy.linalg.sqrtm(np.dot(cov_a + jitter, cov_b + jitter))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(sqrt_prod):
        diag_imag = np.diagonal(sqrt_prod).imag
        if not np.allclose(diag_imag, 0, atol=1e-3):
            worst = np.max(np.abs(sqrt_prod.imag))
            raise ValueError('Imaginary component {}'.format(worst))
        sqrt_prod = sqrt_prod.real

    trace_sqrt = np.trace(sqrt_prod)
    return (np.dot(mean_delta, mean_delta) + np.trace(cov_a) + np.trace(cov_b) - 2 * trace_sqrt)
def kernel_distance(feats1, feats2, num_subsets=100, max_subset_size=1000):
    """
    Compute the KID score given the sets of features

    Uses the cubic polynomial kernel k(u, v) = (u.v / d + 1) ** 3, where d is
    the feature dimension, averaged over ``num_subsets`` random subsets drawn
    without replacement via ``np.random.choice``.

    :param feats1: 2-D array of features, one row per sample
    :param feats2: 2-D array of features, one row per sample
    :param num_subsets: number of random subsets to average over
    :param max_subset_size: upper bound on the number of rows per subset
    :return: KID estimate as a Python float
    """
    feat_dim = feats1.shape[1]
    subset_size = min(feats1.shape[0], feats2.shape[0], max_subset_size)

    def poly_kernel(u, v):
        return (u @ v.T / feat_dim + 1) ** 3

    total = 0
    for _ in range(num_subsets):
        # feats2 is sampled before feats1; the order matters for reproducing
        # results under a fixed numpy random seed.
        x = feats2[np.random.choice(feats2.shape[0], subset_size, replace=False)]
        y = feats1[np.random.choice(feats1.shape[0], subset_size, replace=False)]

        within = poly_kernel(x, x) + poly_kernel(y, y)
        cross = poly_kernel(x, y)

        # Within-set terms exclude the diagonal (self-similarity).
        off_diagonal = within.sum() - np.diag(within).sum()
        total += off_diagonal / (subset_size - 1) - cross.sum() * 2 / subset_size

    return float(total / num_subsets / subset_size)
# load and calculate fid, kid, is
def normalize_point_clouds(pc: np.ndarray) -> np.ndarray:
    """
    Center each point cloud on its centroid and scale it so that its farthest
    point lies at distance 1 from the origin.

    :param pc: batch of point clouds; axis 1 indexes the points of a cloud and
               the last axis holds the coordinates
    :return: normalized batch with the same shape as ``pc``
    """
    centered = pc - np.mean(pc, axis=1, keepdims=True)
    # Per-point Euclidean norm, then the largest norm within each cloud.
    point_norms = np.sqrt(np.square(centered).sum(axis=-1, keepdims=True))
    max_radius = point_norms.max(axis=1, keepdims=True)
    return centered / max_radius
class PCDPathDataset(torch.utils.data.Dataset):
    """
    Dataset over the ``*.ply`` files of one directory.

    Each item is the vertex array of one file, normalized with
    normalize_point_clouds, passed through ``transformation`` and, when
    ``rand_aug`` is enabled, rotated by random whole-degree angles about
    x, y and z.
    """

    def __init__(self, pcd_file_path, transformation, rand_aug=False):
        """
        :param pcd_file_path: directory that is globbed for ``*.ply`` files
        :param transformation: callable applied to each normalized point cloud
        :param rand_aug: if truthy, apply a random rotation after
                         ``transformation``
        """
        # Sorted so that item order is deterministic across runs.
        self.files = sorted(glob.glob(f'{pcd_file_path}/*.ply'))
        self.transformation = transformation
        self.rand_aug = rand_aug

    def __len__(self):
        return len(self.files)

    def _maybe_augment(self, pcd):
        """Return ``pcd`` unchanged, or randomly rotated when rand_aug is enabled."""
        # Truthiness check: the previous ``is not None`` test was also true for
        # the default ``rand_aug=False``, so the rotation could never be
        # switched off.
        if not self.rand_aug:
            return pcd
        rand_rot = [(axis, random.randint(0, 359)) for axis in ('x', 'y', 'z')]
        return rotate_point_cloud(pcd, rotations=rand_rot)  # since no canonical space

    def __getitem__(self, i):
        path = self.files[i]
        pcd = trimesh.load(path).vertices  # pcu may fail sometimes
        # normalize_point_clouds works on a batch, so add and then drop a
        # leading batch axis.
        pcd = normalize_point_clouds(pcd[None])[0]
        pcd = self.transformation(pcd)
        return self._maybe_augment(pcd)
def _method_name_from_dir(batch_dir):
    """Return the last path component of ``batch_dir``, ignoring trailing separators."""
    return os.path.basename(os.path.normpath(batch_dir))


def _load_or_compute_activations(clf, batch_dir, method_name, force_recompute,
                                 num_workers=2, batch_size=64):
    """Return (features, preds) for the ``*.ply`` files in ``batch_dir``.

    Loads ``feat.npy`` / ``pred.npy`` from ``batch_dir`` when both exist and
    ``force_recompute`` is False; otherwise runs the classifier over the
    directory and writes both files.
    """
    feat_path = os.path.join(batch_dir, 'feat.npy')
    pred_path = os.path.join(batch_dir, 'pred.npy')

    if not force_recompute and os.path.exists(feat_path) and os.path.exists(pred_path):
        print("loading activations", batch_dir)
        return np.load(feat_path), np.load(pred_path)

    print("computing activations", batch_dir)
    if method_name not in transformation_dict:
        raise KeyError(
            f"no transformation registered for {method_name!r}; "
            f"known methods: {sorted(transformation_dict)}"
        )
    dataset = PCDPathDataset(batch_dir,
                             transformation=transformation_dict[method_name],
                             rand_aug=True)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         shuffle=False,
                                         drop_last=False,
                                         num_workers=num_workers)
    features, preds = clf.features_and_preds(loader)
    # Features and predictions go to their own files (previously the second
    # batch wrote its predictions over feat.npy and never wrote pred.npy).
    np.save(feat_path, features)
    np.save(pred_path, preds)
    return features, preds


def main():
    """
    Command-line entry point: compute P-FID and P-KID between two directories
    of ``*.ply`` point clouds.

    Positional arguments are ``batch_1`` (treated as the 'gso' reference set)
    and ``batch_2`` (whose directory name selects the entry of
    transformation_dict to apply). ``--cache_dir`` is forwarded to
    PointNetClassifier.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--cache_dir", type=str, default=None)
    parser.add_argument("batch_1", type=str)
    parser.add_argument("batch_2", type=str)
    args = parser.parse_args()

    print("creating classifier...")
    clf = PointNetClassifier(devices=get_torch_devices(), cache_dir=args.cache_dir)

    # Cached feat.npy / pred.npy files are currently always ignored; set this
    # to False to reuse them when both exist.
    force_recompute = True

    # normpath + basename also handles a trailing slash, where
    # split('/')[-1] would return an empty string.
    method_name = _method_name_from_dir(args.batch_2)

    features_1, _preds_1 = _load_or_compute_activations(
        clf, args.batch_1, 'gso', force_recompute)
    features_2, _preds_2 = _load_or_compute_activations(
        clf, args.batch_2, method_name, force_recompute)

    print("computing statistics")
    stats_1 = compute_statistics(features_1)
    stats_2 = compute_statistics(features_2)

    # PFID = stats_1.frechet_distance(stats_2)  # same result as the next line
    PFID = frechet_distance(stats_1.mu, stats_1.sigma, stats_2.mu, stats_2.sigma)
    PKID = kernel_distance(features_1, features_2)

    print(method_name, f"P-FID: {PFID}", f"P-KID: {PKID}")


if __name__ == "__main__":
    main()