import os
from pathlib import Path

import numpy as np
import PIL
import PIL.Image as Image
import pytorch_lightning as pl
import torch
import torchvision
import torchvision.transforms as transforms
import webdataset as wds
from einops import rearrange
from skimage.io import imread
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler

from ldm.base_utils import read_pickle
from ldm.util import prepare_inputs


class SyncDreamerTrainData(Dataset):
    def __init__(self, target_dir, input_dir, uid_set_pkl, image_size=256):
        self.default_image_size = 256
        self.image_size = image_size
        self.target_dir = Path(target_dir)
        self.input_dir = Path(input_dir)

        self.uids = read_pickle(uid_set_pkl)
        print('============= length of dataset %d =============' % len(self.uids))

        # Map [0, 1] tensors to [-1, 1] and move channels last: (c, h, w) -> (h, w, c).
        image_transforms = [
            transforms.ToTensor(),
            transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c')),
        ]
        self.image_transforms = torchvision.transforms.Compose(image_transforms)
        self.num_images = 16  # number of target views rendered per object

    def __len__(self):
        return len(self.uids)

    def load_im(self, path):
        # Load an RGBA image and composite it onto a white background.
        img = imread(path)
        img = img.astype(np.float32) / 255.0
        mask = img[:, :, 3:]
        img[:, :, :3] = img[:, :, :3] * mask + 1 - mask  # white background
        img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
        return img, mask

    def process_im(self, im):
        im = im.convert("RGB")
        im = im.resize((self.image_size, self.image_size), resample=PIL.Image.BICUBIC)
        return self.image_transforms(im)

    def load_index(self, filename, index):
        # Views are stored as zero-padded PNGs: 000.png, 001.png, ...
        img, _ = self.load_im(os.path.join(filename, '%03d.png' % index))
        img = self.process_im(img)
        return img

    def get_data_for_index(self, index):
        target_dir = os.path.join(self.target_dir, self.uids[index])
        input_dir = os.path.join(self.input_dir, self.uids[index])

        # Pick a random starting view, then roll the view order cyclically so
        # the input view always occupies the first position.
        views = np.arange(0, self.num_images)
        start_view_index = np.random.randint(0, self.num_images)
        views = (views + start_view_index) % self.num_images

        target_images = []
        for target_index in views:
            target_images.append(self.load_index(target_dir, target_index))
        target_images = torch.stack(target_images, 0)
        input_img = self.load_index(input_dir, start_view_index)

        K, azimuths, elevations, distances, cam_poses = read_pickle(os.path.join(input_dir, 'meta.pkl'))
        input_elevation = torch.from_numpy(elevations[start_view_index:start_view_index + 1].astype(np.float32))
        return {"target_image": target_images, "input_image": input_img, "input_elevation": input_elevation}

    def __getitem__(self, index):
        return self.get_data_for_index(index)


class SyncDreamerEvalData(Dataset):
    def __init__(self, image_dir):
        self.image_size = 256
        self.image_dir = Path(image_dir)
        self.crop_size = 20

        # Evaluation images are PNGs whose filename stem ends in '-<elevation>'.
        self.fns = [fn for fn in Path(image_dir).iterdir() if fn.suffix == '.png']
        print('============= length of dataset %d =============' % len(self.fns))

    def __len__(self):
        return len(self.fns)

    def get_data_for_index(self, index):
        input_img_fn = self.fns[index]
        elevation = int(Path(input_img_fn).stem.split('-')[-1])
        return prepare_inputs(input_img_fn, elevation, 200)

    def __getitem__(self, index):
        return self.get_data_for_index(index)


class SyncDreamerDataset(pl.LightningDataModule):
    def __init__(self, target_dir, input_dir, validation_dir, batch_size, uid_set_pkl,
                 image_size=256, num_workers=4, seed=0, **kwargs):
        super().__init__()
        self.target_dir = target_dir
        self.input_dir = input_dir
        self.validation_dir = validation_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.uid_set_pkl = uid_set_pkl
        self.seed = seed
        self.additional_args = kwargs
        self.image_size = image_size

    def setup(self, stage):
        # Only the fit stage is supported; validation data is prepared alongside training.
        if stage in ['fit']:
            self.train_dataset = SyncDreamerTrainData(self.target_dir, self.input_dir,
                                                      uid_set_pkl=self.uid_set_pkl,
                                                      image_size=self.image_size)
            self.val_dataset = SyncDreamerEvalData(image_dir=self.validation_dir)
        else:
            raise NotImplementedError

    def train_dataloader(self):
        # DistributedSampler shuffles and shards across ranks, so the loader
        # itself must not shuffle.
        sampler = DistributedSampler(self.train_dataset, seed=self.seed)
        return wds.WebLoader(self.train_dataset, batch_size=self.batch_size,
                             num_workers=self.num_workers, shuffle=False, sampler=sampler)

    def val_dataloader(self):
        return wds.WebLoader(self.val_dataset, batch_size=self.batch_size,
                             num_workers=self.num_workers, shuffle=False)

    def test_dataloader(self):
        return wds.WebLoader(self.val_dataset, batch_size=self.batch_size,
                             num_workers=self.num_workers, shuffle=False)
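

# Minimal usage sketch (not part of the original module): one way this
# DataModule might be wired into a pytorch_lightning Trainer. The paths and
# the `model` object below are hypothetical placeholders, and the
# DistributedSampler in train_dataloader() assumes a distributed environment
# is initialized (e.g. via a ddp strategy).
#
#   dataset = SyncDreamerDataset(
#       target_dir='data/target',          # per-uid directories of target renderings
#       input_dir='data/input',            # per-uid input views plus meta.pkl
#       validation_dir='data/validation',  # PNGs named '<name>-<elevation>.png'
#       uid_set_pkl='data/uid_set.pkl',    # pickled list of object uids
#       batch_size=8,
#   )
#   trainer = pl.Trainer(accelerator='gpu', devices=-1, strategy='ddp')
#   trainer.fit(model, datamodule=dataset)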