Spaces:
Sleeping
Sleeping
File size: 5,093 Bytes
fe3e74d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import pytorch_lightning as pl
import numpy as np
import torch
import PIL
import os
from skimage.io import imread
import webdataset as wds
import PIL.Image as Image
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
from pathlib import Path
from ldm.base_utils import read_pickle, pose_inverse
import torchvision.transforms as transforms
import torchvision
from einops import rearrange
from ldm.util import prepare_inputs
class SyncDreamerTrainData(Dataset):
    """Multi-view training samples, one object (uid) per item.

    Each item loads ``num_images`` target views of one object from
    ``target_dir/<uid>/NNN.png`` plus one input view from
    ``input_dir/<uid>/NNN.png``. The target views are cyclically rotated so
    that a randomly chosen view comes first, and the input view uses that same
    index.

    Args:
        target_dir: Root directory with one sub-directory of target views per
            uid.
        input_dir: Root directory with one sub-directory of input views (and a
            ``meta.pkl``) per uid.
        uid_set_pkl: Path handed to ``read_pickle``; the result is used as the
            indexable collection of uids.
        image_size: Side length, in pixels, every image is resized to.
    """

    def __init__(self, target_dir, input_dir, uid_set_pkl, image_size=256):
        self.default_image_size = 256
        self.image_size = image_size
        self.target_dir = Path(target_dir)
        self.input_dir = Path(input_dir)
        # NOTE(review): read_pickle presumably unpickles the file -- only
        # point it at trusted data.
        self.uids = read_pickle(uid_set_pkl)
        print('============= length of dataset %d =============' % len(self.uids))

        # ToTensor yields a (c, h, w) tensor; rescale it with x * 2 - 1 and
        # move the channel axis last.
        self.image_transforms = torchvision.transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c')),
        ])
        self.num_images = 16

    def __len__(self):
        """Return the number of uids in the dataset."""
        return len(self.uids)

    @staticmethod
    def _composite_on_white(img):
        """Blend an image array over a white background.

        Args:
            img: ``(h, w)`` grayscale, ``(h, w, 3)`` RGB or ``(h, w, 4)`` RGBA
                array with values in 0..255. A fourth channel is used as
                alpha; without one the image is treated as fully opaque.

        Returns:
            ``(rgb, mask)`` where ``rgb`` is a ``uint8`` ``(h, w, 3)`` array
            and ``mask`` is the ``float32`` ``(h, w, 1)`` alpha in ``[0, 1]``.
        """
        img = np.asarray(img).astype(np.float32) / 255.0
        if img.ndim == 2:
            img = img[:, :, None]
        if img.shape[2] == 1:
            img = np.repeat(img, 3, axis=2)

        if img.shape[2] >= 4:
            mask = img[:, :, 3:4]
        else:
            mask = np.ones(img.shape[:2] + (1,), dtype=np.float32)

        rgb = img[:, :, :3] * mask + (1.0 - mask)  # white background
        # Round instead of truncating: float32 error can leave an opaque pixel
        # at e.g. 127.99999, which a plain uint8 cast would turn into 127.
        return np.rint(rgb * 255.0).astype(np.uint8), mask

    def load_im(self, path):
        """Read the image at ``path`` and flatten it onto a white background.

        Returns:
            ``(image, mask)``: an RGB PIL image and the float32 alpha mask.
        """
        rgb, mask = self._composite_on_white(imread(path))
        return Image.fromarray(rgb), mask

    def process_im(self, im):
        """Resize a PIL image to ``image_size`` and apply ``image_transforms``."""
        im = im.convert("RGB")
        im = im.resize((self.image_size, self.image_size), resample=PIL.Image.BICUBIC)
        return self.image_transforms(im)

    def load_index(self, filename, index):
        """Load and preprocess view ``index`` (file ``NNN.png``) from directory ``filename``."""
        img, _ = self.load_im(os.path.join(filename, '%03d.png' % index))
        return self.process_im(img)

    def get_data_for_index(self, index):
        """Assemble the training sample for the uid at position ``index``.

        Returns:
            A dict with:
            ``target_image``: the ``num_images`` preprocessed target views
            stacked along a new first axis, starting at a random view;
            ``input_image``: the preprocessed input view with that same index;
            ``input_elevation``: a float32 tensor built from the length-1
            slice of ``elevations`` at that index.
        """
        uid = self.uids[index]
        target_dir = os.path.join(self.target_dir, uid)
        input_dir = os.path.join(self.input_dir, uid)

        # Rotate the view order so a random view comes first; that view is
        # also the one loaded as the input image.
        start_view_index = np.random.randint(0, self.num_images)
        views = (np.arange(0, self.num_images) + start_view_index) % self.num_images

        target_images = torch.stack(
            [self.load_index(target_dir, view) for view in views], 0
        )
        input_img = self.load_index(input_dir, start_view_index)

        # meta.pkl unpacks to (K, azimuths, elevations, distances, cam_poses);
        # only the elevations are needed here.
        _, _, elevations, _, _ = read_pickle(os.path.join(input_dir, 'meta.pkl'))
        input_elevation = torch.from_numpy(
            elevations[start_view_index:start_view_index + 1].astype(np.float32)
        )
        return {"target_image": target_images, "input_image": input_img, "input_elevation": input_elevation}

    def __getitem__(self, index):
        """Return the sample dict for ``index`` (see ``get_data_for_index``)."""
        return self.get_data_for_index(index)
class SyncDreamerEvalData(Dataset):
    """Evaluation inputs: every ``.png`` file directly inside ``image_dir``.

    The integer after the last ``-`` in each file stem (e.g. ``chair-30.png``
    gives 30) is passed to ``prepare_inputs`` as the elevation.

    Args:
        image_dir: Directory scanned (non-recursively) for ``.png`` files.
    """

    def __init__(self, image_dir):
        self.image_size = 256
        self.image_dir = Path(image_dir)
        self.crop_size = 20
        # Sorted so the index -> file mapping is reproducible; iterdir()
        # yields entries in arbitrary order.
        self.fns = sorted(fn for fn in self.image_dir.iterdir() if fn.suffix == '.png')
        print('============= length of dataset %d =============' % len(self.fns))

    def __len__(self):
        """Return the number of ``.png`` files found."""
        return len(self.fns)

    @staticmethod
    def _parse_elevation(fn):
        """Return the integer that follows the last ``-`` in the stem of ``fn``.

        Raises:
            ValueError: If that trailing token is not an integer.
        """
        token = Path(fn).stem.split('-')[-1]
        try:
            return int(token)
        except ValueError as err:
            raise ValueError(
                f"cannot parse elevation from {str(fn)!r}: "
                "expected a file name like '<name>-<int>.png'"
            ) from err

    def get_data_for_index(self, index):
        """Return ``prepare_inputs`` output for the file at position ``index``."""
        input_img_fn = self.fns[index]
        elevation = self._parse_elevation(input_img_fn)
        # NOTE(review): the literal 200 is passed here while self.crop_size
        # (20) is never read -- confirm which value prepare_inputs should get.
        return prepare_inputs(input_img_fn, elevation, 200)

    def __getitem__(self, index):
        """Return the prepared inputs for ``index`` (see ``get_data_for_index``)."""
        return self.get_data_for_index(index)
class SyncDreamerDataset(pl.LightningDataModule):
    """Lightning data module pairing the training and evaluation datasets.

    Args:
        target_dir: Forwarded to ``SyncDreamerTrainData`` as ``target_dir``.
        input_dir: Forwarded to ``SyncDreamerTrainData`` as ``input_dir``.
        validation_dir: Forwarded to ``SyncDreamerEvalData`` as ``image_dir``.
        batch_size: Batch size used by every loader.
        uid_set_pkl: Forwarded to ``SyncDreamerTrainData``.
        image_size: Forwarded to ``SyncDreamerTrainData`` as ``image_size``.
        num_workers: Worker count used by every loader.
        seed: Seed given to the training ``DistributedSampler``.
        **kwargs: Stored on ``additional_args``; not otherwise used here.
    """

    def __init__(self, target_dir, input_dir, validation_dir, batch_size, uid_set_pkl, image_size=256, num_workers=4, seed=0, **kwargs):
        super().__init__()
        self.target_dir = target_dir
        self.input_dir = input_dir
        self.validation_dir = validation_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.uid_set_pkl = uid_set_pkl
        self.seed = seed
        self.additional_args = kwargs
        self.image_size = image_size

    def setup(self, stage):
        """Create the datasets needed for ``stage``.

        ``'fit'`` builds both the training and evaluation datasets;
        ``'validate'`` and ``'test'`` build only the evaluation dataset, which
        is what ``val_dataloader`` and ``test_dataloader`` read.

        Raises:
            NotImplementedError: For any other stage.
        """
        if stage == 'fit':
            # Honour the configured image_size instead of a hard-coded 256.
            self.train_dataset = SyncDreamerTrainData(
                self.target_dir,
                self.input_dir,
                uid_set_pkl=self.uid_set_pkl,
                image_size=self.image_size,
            )
            self.val_dataset = SyncDreamerEvalData(image_dir=self.validation_dir)
        elif stage in ('validate', 'test'):
            self.val_dataset = SyncDreamerEvalData(image_dir=self.validation_dir)
        else:
            raise NotImplementedError(f'unsupported stage: {stage!r}')

    def train_dataloader(self):
        """Return the training loader, driven by a seeded ``DistributedSampler``."""
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            sampler = DistributedSampler(self.train_dataset, seed=self.seed)
        else:
            # Without a process group the sampler cannot look up world size
            # and rank, so run it as a single replica with the same seed.
            sampler = DistributedSampler(self.train_dataset, num_replicas=1, rank=0, seed=self.seed)
        return wds.WebLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)

    def _eval_loader(self):
        """Build an unshuffled loader over the evaluation dataset."""
        return wds.WebLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)

    def val_dataloader(self):
        """Return the validation loader."""
        return self._eval_loader()

    def test_dataloader(self):
        """Return the test loader; it reads the same dataset as validation."""
        return self._eval_loader()
|