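"""Dataset and LightningDataModule definitions used for SyncDreamer training and validation."""
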
import os
from pathlib import Path

import numpy as np
import PIL
import PIL.Image as Image
import pytorch_lightning as pl
import torch
import torchvision
import torchvision.transforms as transforms
import webdataset as wds
from einops import rearrange
from skimage.io import imread
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler

from ldm.base_utils import read_pickle, pose_inverse
from ldm.util import prepare_inputs


class SyncDreamerTrainData(Dataset):
    """Training dataset: for every object uid it loads 16 rendered target views plus one input view."""

    def __init__(self, target_dir, input_dir, uid_set_pkl, image_size=256):
        self.default_image_size = 256
        self.image_size = image_size
        self.target_dir = Path(target_dir)
        self.input_dir = Path(input_dir)

        self.uids = read_pickle(uid_set_pkl)  # list of object uids, one sub-directory per uid
        print('============= length of dataset %d =============' % len(self.uids))

        # map a PIL image to a float tensor in [-1, 1] with layout (h, w, c)
        image_transforms = [transforms.ToTensor(),
                            transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]
        self.image_transforms = torchvision.transforms.Compose(image_transforms)
        self.num_images = 16  # number of rendered views per object

    def __len__(self):
        return len(self.uids)

    def load_im(self, path):
        """Read an RGBA image and composite it onto a white background; returns (PIL image, alpha mask)."""
        img = imread(path)
        img = img.astype(np.float32) / 255.0
        mask = img[:, :, 3:]
        img[:, :, :3] = img[:, :, :3] * mask + 1 - mask  # white background
        img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
        return img, mask

    def process_im(self, im):
        """Resize to the working resolution and normalize to [-1, 1]."""
        im = im.convert("RGB")
        im = im.resize((self.image_size, self.image_size), resample=PIL.Image.BICUBIC)
        return self.image_transforms(im)

    def load_index(self, filename, index):
        img, _ = self.load_im(os.path.join(filename, '%03d.png' % index))
        img = self.process_im(img)
        return img

    def get_data_for_index(self, index):
        target_dir = os.path.join(self.target_dir, self.uids[index])
        input_dir = os.path.join(self.input_dir, self.uids[index])

        # randomly pick a starting view and roll the view order so that it comes first
        views = np.arange(0, self.num_images)
        start_view_index = np.random.randint(0, self.num_images)
        views = (views + start_view_index) % self.num_images

        target_images = []
        for target_index in views:
            img = self.load_index(target_dir, target_index)
            target_images.append(img)
        target_images = torch.stack(target_images, 0)  # (num_images, h, w, 3)
        input_img = self.load_index(input_dir, start_view_index)

        # meta.pkl stores the intrinsics and per-view azimuths/elevations/distances/camera poses
        K, azimuths, elevations, distances, cam_poses = read_pickle(os.path.join(input_dir, 'meta.pkl'))
        input_elevation = torch.from_numpy(elevations[start_view_index:start_view_index+1].astype(np.float32))
        return {"target_image": target_images, "input_image": input_img, "input_elevation": input_elevation}

    def __getitem__(self, index):
        data = self.get_data_for_index(index)
        return data
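
# Minimal usage sketch (illustrative only; the paths below are placeholders, not part of this
# module). From the loading code above, each uid sub-directory is expected to hold views
# 000.png ... 015.png, and the input directory additionally a meta.pkl with
# (K, azimuths, elevations, distances, cam_poses).
#
#   dataset = SyncDreamerTrainData('renders/target', 'renders/input', 'uids_train.pkl')
#   sample = dataset[0]
#   sample['target_image'].shape     # torch.Size([16, 256, 256, 3]), values in [-1, 1]
#   sample['input_image'].shape      # torch.Size([256, 256, 3])
#   sample['input_elevation'].shape  # torch.Size([1])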


class SyncDreamerEvalData(Dataset):
    """Validation dataset: loads every .png in a directory; the elevation is parsed from the filename."""

    def __init__(self, image_dir):
        self.image_size = 256
        self.image_dir = Path(image_dir)
        self.crop_size = 20

        self.fns = []
        for fn in Path(image_dir).iterdir():
            if fn.suffix == '.png':
                self.fns.append(fn)
        print('============= length of dataset %d =============' % len(self.fns))

    def __len__(self):
        return len(self.fns)

    def get_data_for_index(self, index):
        input_img_fn = self.fns[index]
        # filenames follow the pattern <name>-<elevation>.png, so the elevation is the last '-' field
        elevation = int(Path(input_img_fn).stem.split('-')[-1])
        return prepare_inputs(input_img_fn, elevation, 200)

    def __getitem__(self, index):
        return self.get_data_for_index(index)
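
# Sketch of the expected validation inputs (an assumed layout inferred from the code above):
# a flat directory of PNGs named <name>-<elevation>.png, where <elevation> is an integer in
# degrees; each file is handed to ldm.util.prepare_inputs together with the constant 200
# (presumably a crop size).
#
#   val_set = SyncDreamerEvalData('validation_images')
#   item = val_set[0]  # whatever dict/tensors prepare_inputs returns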


class SyncDreamerDataset(pl.LightningDataModule):
    def __init__(self, target_dir, input_dir, validation_dir, batch_size, uid_set_pkl, image_size=256, num_workers=4, seed=0, **kwargs):
        super().__init__()
        self.target_dir = target_dir
        self.input_dir = input_dir
        self.validation_dir = validation_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.uid_set_pkl = uid_set_pkl
        self.seed = seed
        self.additional_args = kwargs
        self.image_size = image_size

    def setup(self, stage):
        if stage in ['fit']:
            # the training resolution is fixed to 256 here
            self.train_dataset = SyncDreamerTrainData(self.target_dir, self.input_dir, uid_set_pkl=self.uid_set_pkl, image_size=256)
            self.val_dataset = SyncDreamerEvalData(image_dir=self.validation_dir)
        else:
            raise NotImplementedError

    def train_dataloader(self):
        # shuffling is delegated to the DistributedSampler, so the loader itself keeps shuffle=False
        sampler = DistributedSampler(self.train_dataset, seed=self.seed)
        return wds.WebLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)

    def val_dataloader(self):
        loader = wds.WebLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
        return loader

    def test_dataloader(self):
        return wds.WebLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
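

if __name__ == '__main__':
    # Minimal sanity-check sketch with placeholder paths; replace them with your own rendered
    # data. Note that train_dataloader() additionally requires torch.distributed to be
    # initialised, because it constructs a DistributedSampler.
    datamodule = SyncDreamerDataset(
        target_dir='training_examples/target',
        input_dir='training_examples/input',
        validation_dir='validation_examples',
        batch_size=4,
        uid_set_pkl='training_examples/uid_set.pkl',
    )
    datamodule.setup('fit')
    sample = datamodule.train_dataset[0]
    for name, value in sample.items():
        print(name, tuple(value.shape))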