from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from tqdm import tqdm

from point_e.models.download import load_checkpoint

from npz_stream import NpzStreamer
from pointnet2_cls_ssg import get_model


def get_torch_devices() -> List[Union[str, torch.device]]:
    if torch.cuda.is_available():
        return [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())]
    else:
        return ["cpu"]


class FeatureExtractor(ABC):
    @property
    @abstractmethod
    def supports_predictions(self) -> bool:
        pass

    @property
    @abstractmethod
    def feature_dim(self) -> int:
        pass

    @property
    @abstractmethod
    def num_classes(self) -> int:
        pass

    @abstractmethod
    def features_and_preds(self, streamer: NpzStreamer) -> Tuple[np.ndarray, np.ndarray]:
        """
        For a stream of point cloud batches, compute feature vectors and class
        predictions.

        :param streamer: a streamer for a sample batch. Typically, arr_0
                         will contain the XYZ coordinates.
        :return: a tuple (features, predictions)
                 - features: a [B x feature_dim] array of feature vectors.
                 - predictions: a [B x num_classes] array of probabilities.
        """


class PointNetClassifier(FeatureExtractor):
    def __init__(
        self,
        devices: List[Union[str, torch.device]],
        device_batch_size: int = 64,
        cache_dir: Optional[str] = None,
    ):
        state_dict = load_checkpoint("pointnet", device=torch.device("cpu"), cache_dir=cache_dir)[
            "model_state_dict"
        ]
        self.device_batch_size = device_batch_size
        self.devices = devices
        # Single-device variant: load one model onto the first device instead
        # of replicating it across every device.
        model = get_model(num_class=40, normal_channel=False, width_mult=2)
        model.load_state_dict(state_dict)
        model.to(devices[0])
        model.eval()
        self.model = model
    @property
    def supports_predictions(self) -> bool:
        return True

    @property
    def feature_dim(self) -> int:
        return 256

    @property
    def num_classes(self) -> int:
        return 40
    def features_and_preds(self, streamer) -> Tuple[np.ndarray, np.ndarray]:
        # Unlike the abstract signature, `streamer` here is expected to be a
        # PyTorch-style iterable (e.g. a DataLoader) that yields [B x N x 3]
        # point cloud tensors.
        device = self.devices[0]
        output_features = []
        output_predictions = []
        for batch in tqdm(streamer):
            # Point clouds are assumed to be pre-normalized upstream;
            # otherwise apply normalize_point_clouds first.
            # PointNet expects channels-first input: [B x N x 3] -> [B x 3 x N].
            batch = batch.to(dtype=torch.float32, device=device).permute(0, 2, 1)
            with torch.no_grad():
                # The classifier returns log-probabilities, an auxiliary
                # output, and the penultimate-layer features.
                logits, _, features = self.model(batch, features=True)
            output_features.append(features.cpu().numpy())
            output_predictions.append(logits.exp().cpu().numpy())
        return np.concatenate(output_features, axis=0), np.concatenate(output_predictions, axis=0)
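

# A minimal usage sketch (not part of the original file). It assumes batches
# come from a PyTorch DataLoader over a [num_samples x N x 3] array; the
# `npz_path` argument and the "arr_0" key are hypothetical placeholders.
def _example_extract_features(npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
    from torch.utils.data import DataLoader

    clouds = np.load(npz_path)["arr_0"].astype(np.float32)
    loader = DataLoader(torch.from_numpy(clouds), batch_size=64)
    classifier = PointNetClassifier(devices=get_torch_devices())
    return classifier.features_and_preds(loader)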


def normalize_point_clouds(pc: np.ndarray) -> np.ndarray:
    # Center each [N x 3] cloud at its centroid and scale it into the unit sphere.
    centroids = np.mean(pc, axis=1, keepdims=True)
    pc = pc - centroids
    m = np.max(np.sqrt(np.sum(pc**2, axis=-1, keepdims=True)), axis=1, keepdims=True)
    pc = pc / m
    return pc
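

# Hypothetical sanity check (not part of the original file): after
# normalization every cloud should be centered at the origin and fit
# inside the unit sphere.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    clouds = rng.normal(loc=5.0, scale=2.0, size=(4, 1024, 3))
    normalized = normalize_point_clouds(clouds)
    assert np.allclose(normalized.mean(axis=1), 0.0, atol=1e-6)
    assert np.all(np.linalg.norm(normalized, axis=-1) <= 1.0 + 1e-6)
    print("normalize_point_clouds: OK")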