import json
import random
from dataclasses import dataclass, field

import imageio
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from transformers import CLIPTokenizer

from craftsman import register
from craftsman.utils.config import parse_structured
from craftsman.utils.typing import *

def rot2eul(R):
    """Recover extrinsic X-Y-Z Euler angles (alpha, beta, gamma) from
    R = Rz(gamma) @ Ry(beta) @ Rx(alpha); degenerate when cos(beta) == 0."""
    beta = -np.arcsin(R[2, 0])
    alpha = np.arctan2(R[2, 1] / np.cos(beta), R[2, 2] / np.cos(beta))
    gamma = np.arctan2(R[1, 0] / np.cos(beta), R[0, 0] / np.cos(beta))
    return np.array((alpha, beta, gamma))

def eul2rot(theta):
    """Inverse of rot2eul: builds R = Rz(theta[2]) @ Ry(theta[1]) @ Rx(theta[0])."""
    sx, sy, sz = np.sin(theta[0]), np.sin(theta[1]), np.sin(theta[2])
    cx, cy, cz = np.cos(theta[0]), np.cos(theta[1]), np.cos(theta[2])
    R = np.array([
        [cy * cz, sx * sy * cz - sz * cx, sy * cx * cz + sx * sz],
        [sz * cy, sx * sy * sz + cx * cz, sy * sz * cx - sx * cz],
        [-sy, sx * cy, cx * cy],
    ])
    return R

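# A quick round-trip sanity check for the two helpers above (a sketch to run
# manually; holds for beta in (-pi/2, pi/2), away from gimbal lock):
#
#   theta = np.array([0.1, -0.4, 2.0])  # (alpha, beta, gamma)
#   assert np.allclose(rot2eul(eul2rot(theta)), theta)
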
@dataclass
class ObjaverseDataModuleConfig:
    # shape loading
    root_dir: Optional[str] = None
    data_type: str = "occupancy"  # "occupancy" | "sdf"
    n_samples: int = 4096         # number of surface points fed to the model
    scale: float = 1.0            # global scale applied to all point coordinates
    noise_sigma: float = 0.0      # amplitude of uniform jitter on surface points

    # shape supervision
    load_supervision: bool = True
    supervision_type: str = "occupancy"  # "occupancy" | "sdf"
    n_supervision: int = 10000           # number of query points per sample

    # image conditioning
    load_image: bool = False
    image_data_path: str = ""
    image_type: str = "rgb"  # "rgb" | "normal" | "mvrgb" | "mvnormal"
    background_color: Tuple[float, float, float] = field(
        default_factory=lambda: (1.0, 1.0, 1.0)
    )
    idx: Optional[List[int]] = None  # candidate view indices to sample from
    n_views: int = 1
    rotate_points: bool = False  # rotate camera poses by the per-view rotation

    # caption conditioning
    load_caption: bool = False
    caption_type: str = "text"
    tokenizer_pretrained_model_name_or_path: str = ""

    # dataloader
    batch_size: int = 32
    num_workers: int = 0

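# A minimal example of a matching config section (field names come from the
# dataclass above; the YAML form and all values are hypothetical, assuming an
# OmegaConf-style config as suggested by the DictConfig type hint below):
#
#   data_type: "sdf"
#   root_dir: "data/objaverse"
#   n_samples: 4096
#   load_image: true
#   image_data_path: "data/objaverse_renders"
#   image_type: "rgb"
#   idx: [0, 4, 8, 12]
#   batch_size: 32
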
class ObjaverseDataset(Dataset):
    def __init__(self, cfg: Any, split: str) -> None:
        super().__init__()
        self.cfg: ObjaverseDataModuleConfig = cfg
        self.split = split

        # each split file is a JSON list of relative uids, e.g. "<group>/<uid>"
        with open(f'{cfg.root_dir}/{split}.json') as f:
            self.uids = json.load(f)
        print(f"Loaded {len(self.uids)} {split} uids")

        if self.cfg.load_caption:
            self.tokenizer = CLIPTokenizer.from_pretrained(self.cfg.tokenizer_pretrained_model_name_or_path)

        self.background_color = torch.as_tensor(self.cfg.background_color)

        # canonical camera-to-world poses for four orthogonal views around the
        # object, all at unit distance from the origin
        self.distance = 1.0
        self.camera_embedding = torch.as_tensor([
            [[1, 0, 0, 0],
             [0, 0, -1, -self.distance],
             [0, 1, 0, 0],
             [0, 0, 0, 1]],

            [[0, 0, 1, self.distance],
             [1, 0, 0, 0],
             [0, 1, 0, 0],
             [0, 0, 0, 1]],

            [[-1, 0, 0, 0],
             [0, 0, 1, self.distance],
             [0, 1, 0, 0],
             [0, 0, 0, 1]],

            [[0, 0, -1, -self.distance],
             [-1, 0, 0, 0],
             [0, 1, 0, 0],
             [0, 0, 0, 1]],
        ], dtype=torch.float32)
        if self.cfg.n_views != 1:
            assert self.cfg.n_views == self.camera_embedding.shape[0]

    def __len__(self):
        return len(self.uids)

    def _load_shape(self, index: int):
        if self.cfg.data_type == "occupancy":
            # Occupancy-Networks-style data: surface points plus normals; the
            # stored points presumably live in [-0.5, 0.5], hence the * 2
            pointcloud = np.load(f'{self.cfg.root_dir}/{self.uids[index]}/pointcloud.npz')
            surface = np.asarray(pointcloud['points']) * 2
            normal = np.asarray(pointcloud['normals'])
            surface = np.concatenate([surface, normal], axis=1)
        elif self.cfg.data_type == "sdf":
            # pre-packed array of surface points and normals
            data = np.load(f'{self.cfg.root_dir}/{self.uids[index]}.npz')
            surface = data["surface"]
        else:
            raise NotImplementedError(f"Data type {self.cfg.data_type} not implemented")

        # random subsample without replacement down to n_samples points
        rng = np.random.default_rng()
        ind = rng.choice(surface.shape[0], self.cfg.n_samples, replace=False)
        surface = surface[ind]

        # global rescale, then jitter positions with noise drawn uniformly from
        # [-noise_sigma, noise_sigma] (uniform, despite the "sigma" name)
        surface[:, :3] = surface[:, :3] * self.cfg.scale
        surface[:, :3] += (np.random.rand(surface.shape[0], 3) * 2 - 1) * self.cfg.noise_sigma

        ret = {
            "uid": self.uids[index].split('/')[-1],
            "surface": surface.astype(np.float32),
        }
        return ret

    def _load_shape_supervision(self, index: int):
        ret = {}
        if self.cfg.data_type == "occupancy":
            # query points with bit-packed binary occupancy labels
            points = np.load(f'{self.cfg.root_dir}/{self.uids[index]}/points.npz')
            rand_points = np.asarray(points['points']) * 2
            occupancies = np.asarray(points['occupancies'])
            occupancies = np.unpackbits(occupancies)  # one bit per query point
        elif self.cfg.data_type == "sdf":
            # query points with signed-distance values
            data = np.load(f'{self.cfg.root_dir}/{self.uids[index]}.npz')
            rand_points = data['rand_points']
            sdfs = data['sdfs']
        else:
            raise NotImplementedError(f"Data type {self.cfg.data_type} not implemented")

        # random subsample down to n_supervision query points
        rng = np.random.default_rng()
        ind = rng.choice(rand_points.shape[0], self.cfg.n_supervision, replace=False)
        rand_points = rand_points[ind]
        rand_points = rand_points * self.cfg.scale
        ret["rand_points"] = rand_points.astype(np.float32)

        if self.cfg.data_type == "occupancy":
            assert self.cfg.supervision_type == "occupancy", "Only occupancy supervision is supported for occupancy data"
            occupancies = occupancies[ind]
            ret["occupancies"] = occupancies.astype(np.float32)
        elif self.cfg.data_type == "sdf":
            if self.cfg.supervision_type == "sdf":
                ret["sdf"] = sdfs[ind].flatten().astype(np.float32)
            elif self.cfg.supervision_type == "occupancy":
                # binarize the SDF with a small threshold: 0 where sdf < 1e-3,
                # 1 elsewhere
                ret["occupancies"] = np.where(sdfs[ind].flatten() < 1e-3, 0, 1).astype(np.float32)
            else:
                raise NotImplementedError(f"Supervision type {self.cfg.supervision_type} not implemented")

        return ret

    def _load_image(self, index: int):
        def _load_single_image(img_path):
            # read as RGBA, then alpha-composite onto the configured background
            img = torch.from_numpy(
                np.asarray(
                    Image.fromarray(imageio.v2.imread(img_path))
                    .convert("RGBA")
                )
                / 255.0
            ).float()
            mask: Float[Tensor, "H W 1"] = img[:, :, -1:]
            image: Float[Tensor, "H W 3"] = img[:, :, :3] * mask + self.background_color[
                None, None, :
            ] * (1 - mask)
            return image

        ret = {}
        if self.cfg.image_type in ("rgb", "normal"):
            assert self.cfg.n_views == 1, "Only single view is supported for single image"
            # pick one view index from the configured candidates
            sel_idx = random.choice(self.cfg.idx)
            ret["sel_image_idx"] = sel_idx
            if self.cfg.image_type == "rgb":
                img_path = f'{self.cfg.image_data_path}/' + "/".join(self.uids[index].split('/')[-2:]) + f"/{sel_idx}.png"
            elif self.cfg.image_type == "normal":
                img_path = f'{self.cfg.image_data_path}/' + "/".join(self.uids[index].split('/')[-2:]) + f"/{sel_idx}_normal.png"
            ret["image"] = _load_single_image(img_path)
            # views are assumed to be grouped in fours, matching the four
            # canonical camera poses
            ret["c2w"] = self.camera_embedding[sel_idx % 4]
        elif self.cfg.image_type in ("mvrgb", "mvnormal"):
            sel_idx = random.choice(self.cfg.idx)
            ret["sel_image_idx"] = sel_idx
            # load n_views consecutive renders starting at sel_idx
            mvimages = []
            for i in range(self.cfg.n_views):
                if self.cfg.image_type == "mvrgb":
                    img_path = f'{self.cfg.image_data_path}/' + "/".join(self.uids[index].split('/')[-2:]) + f"/{sel_idx+i}.png"
                elif self.cfg.image_type == "mvnormal":
                    img_path = f'{self.cfg.image_data_path}/' + "/".join(self.uids[index].split('/')[-2:]) + f"/{sel_idx+i}_normal.png"
                mvimages.append(_load_single_image(img_path))
            ret["mvimages"] = torch.stack(mvimages)
            ret["c2ws"] = self.camera_embedding
        else:
            raise NotImplementedError(f"Image type {self.cfg.image_type} not implemented")

        return ret

    def _load_caption(self, index: int, drop_text_embed: bool = False):
        ret = {}
        if self.cfg.caption_type == "text":
            # annotation.json apparently stores a stringified dict, hence the
            # eval() after json.load(); each value is one candidate caption
            caption = eval(json.load(open(f'{self.cfg.image_data_path}/' + "/".join(self.uids[index].split('/')[-2:]) + '/annotation.json')))
            texts = [v for k, v in caption.items()]
            sel_idx = random.randint(0, len(texts) - 1)
            ret["sel_caption_idx"] = sel_idx
            # drop_text_embed swaps in the empty string, e.g. for
            # classifier-free-guidance style caption dropout
            ret['text_input_ids'] = self.tokenizer(
                texts[sel_idx] if not drop_text_embed else "",
                max_length=self.tokenizer.model_max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt"
            ).input_ids.detach()
        else:
            raise NotImplementedError(f"Caption type {self.cfg.caption_type} not implemented")

        return ret

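    # On-disk layout implied by the loaders above (a reconstruction from the
    # path construction in this file, not a verified spec):
    #
    #   {root_dir}/{split}.json                          # list of uids
    #   {root_dir}/{uid}/pointcloud.npz                  # "occupancy" data
    #   {root_dir}/{uid}/points.npz
    #   {root_dir}/{uid}.npz                             # "sdf" data
    #   {image_data_path}/{group}/{uid}/{i}.png          # RGB renders
    #   {image_data_path}/{group}/{uid}/{i}_normal.png   # normal renders
    #   {image_data_path}/{group}/{uid}/rots.npy         # per-view rotations
    #   {image_data_path}/{group}/{uid}/annotation.json  # captions
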
    def get_data(self, index):
        ret = self._load_shape(index)

        if self.cfg.load_supervision:
            ret.update(self._load_shape_supervision(index))

        if self.cfg.load_image:
            ret.update(self._load_image(index))

            # Optionally rotate the camera pose(s) by the per-view rotation
            # stored alongside the renders. This block must be guarded on
            # cfg.rotate_points and nested under load_image: 'sel_image_idx'
            # only exists once an image has been loaded.
            if self.cfg.rotate_points:
                rots = np.load(f'{self.cfg.image_data_path}/' + "/".join(self.uids[index].split('/')[-2:]) + '/rots.npy')[ret['sel_image_idx']].astype(np.float32)
                rots = torch.tensor(rots[:3, :3], dtype=torch.float32)
                if "c2ws" in ret.keys():
                    ret["c2ws"][:, :3, :3] = torch.matmul(rots, ret["c2ws"][:, :3, :3])
                    ret["c2ws"][:, :3, 3] = torch.matmul(rots, ret["c2ws"][:, :3, 3].unsqueeze(-1)).squeeze(-1)
                elif "c2w" in ret.keys():
                    ret["c2w"][:3, :3] = torch.matmul(rots, ret["c2w"][:3, :3])
                    ret["c2w"][:3, 3] = torch.matmul(rots, ret["c2w"][:3, 3].unsqueeze(-1)).squeeze(-1)

        if self.cfg.load_caption:
            ret.update(self._load_caption(index))

        return ret

    def __getitem__(self, index):
        try:
            return self.get_data(index)
        except Exception as e:
            # skip corrupt or missing samples by retrying with a random index
            # (recursion terminates only if most samples are loadable)
            print(f"Error in {self.uids[index]}: {e}")
            return self.__getitem__(np.random.randint(len(self)))

    def collate(self, batch):
        # delegate to PyTorch's default collate
        batch = torch.utils.data.default_collate(batch)
        return batch

@register("objaverse-datamodule")
class ObjaverseDataModule(pl.LightningDataModule):
    cfg: ObjaverseDataModuleConfig

    def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
        super().__init__()
        self.cfg = parse_structured(ObjaverseDataModuleConfig, cfg)

    def setup(self, stage=None) -> None:
        # instantiate only the splits needed for the current stage
        if stage in [None, "fit"]:
            self.train_dataset = ObjaverseDataset(self.cfg, "train")
        if stage in [None, "fit", "validate"]:
            self.val_dataset = ObjaverseDataset(self.cfg, "val")
        if stage in [None, "test", "predict"]:
            self.test_dataset = ObjaverseDataset(self.cfg, "test")

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size, collate_fn=None, num_workers=0) -> DataLoader:
        return DataLoader(
            dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers
        )

    def train_dataloader(self) -> DataLoader:
        return self.general_loader(
            self.train_dataset,
            batch_size=self.cfg.batch_size,
            collate_fn=self.train_dataset.collate,
            num_workers=self.cfg.num_workers
        )

    # validation/test/prediction run one sample at a time with default collation
    def val_dataloader(self) -> DataLoader:
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self) -> DataLoader:
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self) -> DataLoader:
        return self.general_loader(self.test_dataset, batch_size=1)
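

# Minimal usage sketch (paths and values are hypothetical; assumes this module
# is importable as part of the craftsman package):
#
#   dm = ObjaverseDataModule({
#       "root_dir": "data/objaverse",
#       "data_type": "sdf",
#       "batch_size": 8,
#   })
#   dm.setup("fit")
#   batch = next(iter(dm.train_dataloader()))
#   print(batch["surface"].shape)  # e.g. (8, n_samples, 6): xyz + normal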