V3D / sgm /data /latent_objaverse.py
heheyas
init
cfb7702
raw
history blame
1.55 kB
import numpy as np
from pathlib import Path
from PIL import Image
import json
import torch
from torch.utils.data import Dataset, DataLoader, default_collate
from torchvision.transforms import ToTensor, Normalize, Compose, Resize
from pytorch_lightning import LightningDataModule
from einops import rearrange
class LatentObjaverseSpiral(Dataset):
def __init__(
self,
root_dir,
split="train",
transform=None,
random_front=False,
max_item=None,
cond_aug_mean=-3.0,
cond_aug_std=0.5,
condition_on_elevation=False,
**unused_kwargs,
):
print("Using LVIS subset with precomputed Latents")
self.root_dir = Path(root_dir)
self.split = split
self.random_front = random_front
self.transform = transform
self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512")
self.ids = json.load(open("./assets/lvis_uids.json", "r"))
self.n_views = 18
valid_ids = []
for idx in self.ids:
if (self.root_dir / idx).exists():
valid_ids.append(idx)
self.ids = valid_ids
print("=" * 30)
print("Number of valid ids: ", len(self.ids))
print("=" * 30)
self.cond_aug_mean = cond_aug_mean
self.cond_aug_std = cond_aug_std
self.condition_on_elevation = condition_on_elevation
if max_item is not None:
self.ids = self.ids[:max_item]
## debug
self.ids = self.ids * 10000