from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from tqdm import tqdm

from point_e.models.download import load_checkpoint
from npz_stream import NpzStreamer
from pointnet2_cls_ssg import get_model


def get_torch_devices() -> List[Union[str, torch.device]]:
    if torch.cuda.is_available():
        return [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())]
    return ["cpu"]


class FeatureExtractor(ABC):
    @abstractmethod
    def supports_predictions(self) -> bool:
        pass

    @abstractmethod
    def feature_dim(self) -> int:
        pass

    @abstractmethod
    def num_classes(self) -> int:
        pass

    @abstractmethod
    def features_and_preds(self, streamer: NpzStreamer) -> Tuple[np.ndarray, np.ndarray]:
""" | |
For a stream of point cloud batches, compute feature vectors and class | |
predictions. | |
:param point_clouds: a streamer for a sample batch. Typically, arr_0 | |
will contain the XYZ coordinates. | |
:return: a tuple (features, predictions) | |
- features: a [B x feature_dim] array of feature vectors. | |
- predictions: a [B x num_classes] array of probabilities. | |
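
        Example (shapes only; the extractor and streamer here are assumed to
        exist, and B is the total number of clouds streamed):

            features, preds = extractor.features_and_preds(streamer)
            # features.shape == (B, extractor.feature_dim())
            # preds.shape == (B, extractor.num_classes())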
""" | |


class PointNetClassifier(FeatureExtractor):
    def __init__(
        self,
        devices: List[Union[str, torch.device]],
        device_batch_size: int = 64,
        cache_dir: Optional[str] = None,
    ):
        state_dict = load_checkpoint("pointnet", device=torch.device("cpu"), cache_dir=cache_dir)[
            "model_state_dict"
        ]
        self.device_batch_size = device_batch_size
        self.devices = devices
        # This variant runs on a single device; only devices[0] is used.
        self.device = devices[0]
        model = get_model(num_class=40, normal_channel=False, width_mult=2)
        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()
        self.model = model

    def supports_predictions(self) -> bool:
        return True

    def feature_dim(self) -> int:
        return 256

    def num_classes(self) -> int:
        return 40

    def features_and_preds(self, streamer) -> Tuple[np.ndarray, np.ndarray]:
        # Unlike the NpzStreamer interface in the base class, this variant
        # expects a PyTorch-style iterable yielding [B x N x 3] batches.
        output_features = []
        output_predictions = []
        for batch in tqdm(streamer):  # type: ignore
            # [B x N x 3] -> [B x 3 x N], the channel-first layout PointNet expects.
            batch = batch.to(dtype=torch.float32, device=self.device).permute(0, 2, 1)
            with torch.no_grad():
                logits, _, features = self.model(batch, features=True)
            output_features.append(features.cpu().numpy())
            # The classifier head returns log-probabilities; exp() recovers probabilities.
            output_predictions.append(logits.exp().cpu().numpy())
        return np.concatenate(output_features, axis=0), np.concatenate(output_predictions, axis=0)


def normalize_point_clouds(pc: np.ndarray) -> np.ndarray:
    """Center each [N x 3] cloud at the origin and scale it into the unit sphere."""
    centroids = np.mean(pc, axis=1, keepdims=True)
    pc = pc - centroids
    m = np.max(np.sqrt(np.sum(pc**2, axis=-1, keepdims=True)), axis=1, keepdims=True)
    pc = pc / m
    return pc
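

if __name__ == "__main__":
    # Minimal usage sketch, not part of the extractor itself. The batch size
    # (8), point count (1024), and random clouds below are illustrative
    # assumptions; only the [B x N x 3] layout is required.
    devices = get_torch_devices()
    clf = PointNetClassifier(devices=devices)

    # Normalize each cloud into the unit sphere before classification, using
    # the normalize_point_clouds helper defined above.
    batches = [
        torch.from_numpy(normalize_point_clouds(np.random.randn(8, 1024, 3)))
        for _ in range(2)
    ]
    features, preds = clf.features_and_preds(batches)
    print(features.shape)  # (16, 256): one 256-d feature vector per cloud
    print(preds.shape)     # (16, 40): probabilities over 40 classes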