Spaces:
Runtime error
Runtime error
import numpy as np | |
from pathlib import Path | |
from PIL import Image | |
import json | |
import torch | |
from torch.utils.data import Dataset, DataLoader, default_collate | |
from torchvision.transforms import ToTensor, Normalize, Compose, Resize | |
from pytorch_lightning import LightningDataModule | |
from einops import rearrange | |
class LatentObjaverseSpiral(Dataset): | |
def __init__(
    self,
    root_dir,
    split="train",
    transform=None,
    random_front=False,
    max_item=None,
    cond_aug_mean=-3.0,
    cond_aug_std=0.5,
    condition_on_elevation=False,
    latent_dir="/mnt/vepfs/3Ddataset/render_results/latents512",
    uids_file="./assets/lvis_uids.json",
    debug_repeat=10000,
    **unused_kwargs,
):
    """Build the index of LVIS-subset object UIDs that have precomputed latents.

    Loads candidate UIDs from a JSON file, keeps only those that have an
    entry under ``root_dir``, and optionally truncates (and repeats) the list.

    Args:
        root_dir: Directory expected to contain one entry per object UID;
            UIDs with no ``root_dir / uid`` entry are dropped.
        split: Stored as ``self.split``; not otherwise used in this method.
        transform: Stored as ``self.transform``; not applied in this method.
        random_front: Stored as ``self.random_front``.
        max_item: If not None, keep only the first ``max_item`` valid UIDs,
            then repeat that truncated list ``debug_repeat`` times.
        cond_aug_mean: Stored as ``self.cond_aug_mean``.
        cond_aug_std: Stored as ``self.cond_aug_std``.
        condition_on_elevation: Stored as ``self.condition_on_elevation``.
        latent_dir: Directory of precomputed latents, stored as a ``Path`` on
            ``self.latent_dir``. Defaults to the previously hard-coded path.
        uids_file: Path to a JSON file holding the candidate UID list.
            Defaults to the previously hard-coded path, which is resolved
            relative to the current working directory.
        debug_repeat: Repeat factor applied to the UID list when ``max_item``
            is set (a debugging aid kept from the original code). Pass 1 to
            disable the repetition.
        **unused_kwargs: Accepted and ignored.

    Raises:
        OSError: If ``uids_file`` cannot be opened.
        json.JSONDecodeError: If ``uids_file`` is not valid JSON.
    """
    print("Using LVIS subset with precomputed Latents")
    self.root_dir = Path(root_dir)
    self.split = split
    self.random_front = random_front
    self.transform = transform
    self.latent_dir = Path(latent_dir)

    # Context manager closes the handle deterministically; the previous
    # json.load(open(...)) left closing to the garbage collector.
    with open(uids_file, "r", encoding="utf-8") as f:
        candidate_ids = json.load(f)
    self.n_views = 18

    # Keep only UIDs that have a corresponding entry under root_dir.
    self.ids = [uid for uid in candidate_ids if (self.root_dir / uid).exists()]
    print("=" * 30)
    print("Number of valid ids: ", len(self.ids))
    print("=" * 30)

    self.cond_aug_mean = cond_aug_mean
    self.cond_aug_std = cond_aug_std
    self.condition_on_elevation = condition_on_elevation

    if max_item is not None:
        self.ids = self.ids[:max_item]
        # Debug aid from the original code: repeat the truncated list so a
        # tiny subset yields long epochs. NOTE(review): the scraped source
        # lost its indentation; this assumes the repeat belonged inside the
        # max_item branch — confirm against upstream.
        self.ids = self.ids * debug_repeat