import os
import json
import math
import numpy as np
from PIL import Image
import cv2
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, IterableDataset
import torchvision.transforms.functional as TF
import pytorch_lightning as pl
import datasets
from models.ray_utils import get_ray_directions
from utils.misc import get_rank
def load_K_Rt_from_P(P):
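    """Decompose a 3x4 projection matrix into camera parameters.

    Uses cv2.decomposeProjectionMatrix to recover the 3x3 intrinsics K
    (returned embedded in a 4x4 matrix) and the 4x4 camera-to-world pose,
    whose rotation is R^T and whose translation is the camera center t[:3]/t[3].
    """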
out = cv2.decomposeProjectionMatrix(P)
K = out[0]
R = out[1]
t = out[2]
K = K / K[2, 2]
intrinsics = np.eye(4)
intrinsics[:3, :3] = K
pose = np.eye(4, dtype=np.float32)
pose[:3, :3] = R.transpose()
pose[:3, 3] = (t[:3] / t[3])[:, 0]
return intrinsics, pose
def create_spheric_poses(cameras, n_steps=120):
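    """Generate a spheric camera trajectory for novel-view rendering.

    cameras: (N, 3) camera positions; the sweep rotates the mean camera
    position about an axis estimated from the camera distribution, covering
    the angular range spanned by the input cameras.
    Returns an (n_steps, 3, 4) tensor of camera-to-world matrices looking
    at the origin.
    """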
center = torch.as_tensor([0.,0.,0.], dtype=cameras.dtype, device=cameras.device)
cam_center = F.normalize(cameras.mean(0), p=2, dim=-1) * cameras.mean(0).norm(2)
eigvecs = torch.linalg.eig(cameras.T @ cameras).eigenvectors
rot_axis = F.normalize(eigvecs[:,1].real.float(), p=2, dim=-1)
up = rot_axis
    rot_dir = torch.cross(rot_axis, cam_center, dim=-1)
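    # widest angle between any camera position and the mean camera direction,
    # used as the half-range of the sweep below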
max_angle = (F.normalize(cameras, p=2, dim=-1) * F.normalize(cam_center, p=2, dim=-1)).sum(-1).acos().max()
all_c2w = []
for theta in torch.linspace(-max_angle, max_angle, n_steps):
cam_pos = cam_center * math.cos(theta) + rot_dir * math.sin(theta)
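        # build an OpenGL-style look-at frame: l = forward (towards the scene
        # center), s = right, u = up; columns of c2w are (right, up, back)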
l = F.normalize(center - cam_pos, p=2, dim=0)
        s = F.normalize(torch.cross(l, up, dim=0), p=2, dim=0)
        u = F.normalize(torch.cross(s, l, dim=0), p=2, dim=0)
        c2w = torch.cat([torch.stack([s, u, -l], dim=1), cam_pos[:,None]], dim=1)
all_c2w.append(c2w)
all_c2w = torch.stack(all_c2w, dim=0)
return all_c2w
class DTUDatasetBase():
def setup(self, config, split):
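        """Load the DTU cameras and, for train/val, the images and foreground
        masks; precompute per-view ray directions at the target resolution."""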
self.config = config
self.split = split
self.rank = get_rank()
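        # the cameras .npz stores one world_mat_i / scale_mat_i pair per view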
cams = np.load(os.path.join(self.config.root_dir, self.config.cameras_file))
img_sample = cv2.imread(os.path.join(self.config.root_dir, 'image', '000000.png'))
H, W = img_sample.shape[0], img_sample.shape[1]
if 'img_wh' in self.config:
w, h = self.config.img_wh
            assert round(W / w * h) == H, 'img_wh must preserve the original aspect ratio'
elif 'img_downscale' in self.config:
w, h = int(W / self.config.img_downscale + 0.5), int(H / self.config.img_downscale + 0.5)
else:
raise KeyError("Either img_wh or img_downscale should be specified.")
self.w, self.h = w, h
self.img_wh = (w, h)
self.factor = w / W
mask_dir = os.path.join(self.config.root_dir, 'mask')
self.has_mask = True
self.apply_mask = self.config.apply_mask
self.directions = []
self.all_c2w, self.all_images, self.all_fg_masks = [], [], []
n_images = max([int(k.split('_')[-1]) for k in cams.keys()]) + 1
for i in range(n_images):
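            # world_mat projects world coordinates to pixels; scale_mat normalizes
            # the scene (NeuS convention), so P maps normalized scene coordinates
            # to pixels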
world_mat, scale_mat = cams[f'world_mat_{i}'], cams[f'scale_mat_{i}']
P = (world_mat @ scale_mat)[:3,:4]
K, c2w = load_K_Rt_from_P(P)
fx, fy, cx, cy = K[0,0] * self.factor, K[1,1] * self.factor, K[0,2] * self.factor, K[1,2] * self.factor
directions = get_ray_directions(w, h, fx, fy, cx, cy)
self.directions.append(directions)
c2w = torch.from_numpy(c2w).float()
            # Blender follows the OpenGL camera convention (right, up, back),
            # while the NeuS DTU coordinate system is (right, down, front); see
            # https://github.com/Totoro97/NeuS/issues/9
            # Convert by flipping the sign of the camera y and z axes.
            c2w_ = c2w.clone()
            c2w_[:3,1:3] *= -1. # flip the y and z columns
self.all_c2w.append(c2w_[:3,:4])
if self.split in ['train', 'val']:
img_path = os.path.join(self.config.root_dir, 'image', f'{i:06d}.png')
img = Image.open(img_path)
img = img.resize(self.img_wh, Image.BICUBIC)
img = TF.to_tensor(img).permute(1, 2, 0)[...,:3]
mask_path = os.path.join(mask_dir, f'{i:03d}.png')
                mask = Image.open(mask_path).convert('L') # single-channel (H, W) mask
mask = mask.resize(self.img_wh, Image.BICUBIC)
mask = TF.to_tensor(mask)[0]
self.all_fg_masks.append(mask) # (h, w)
self.all_images.append(img)
self.all_c2w = torch.stack(self.all_c2w, dim=0)
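        # test split: render along a synthetic spheric trajectory; there is no
        # ground truth along it, so images and masks are zero placeholders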
if self.split == 'test':
self.all_c2w = create_spheric_poses(self.all_c2w[:,:,3], n_steps=self.config.n_test_traj_steps)
self.all_images = torch.zeros((self.config.n_test_traj_steps, self.h, self.w, 3), dtype=torch.float32)
self.all_fg_masks = torch.zeros((self.config.n_test_traj_steps, self.h, self.w), dtype=torch.float32)
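            # the synthetic trajectory reuses the ray directions (intrinsics) of view 0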
self.directions = self.directions[0]
else:
self.all_images, self.all_fg_masks = torch.stack(self.all_images, dim=0), torch.stack(self.all_fg_masks, dim=0)
self.directions = torch.stack(self.directions, dim=0)
self.directions = self.directions.float().to(self.rank)
self.all_c2w, self.all_images, self.all_fg_masks = \
self.all_c2w.float().to(self.rank), \
self.all_images.float().to(self.rank), \
self.all_fg_masks.float().to(self.rank)
class DTUDataset(Dataset, DTUDatasetBase):
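    """Map-style dataset for validation/testing: items carry only the view
    index; the image, mask and ray tensors live on the dataset itself."""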
def __init__(self, config, split):
self.setup(config, split)
def __len__(self):
return len(self.all_images)
def __getitem__(self, index):
return {
'index': index
}
class DTUIterableDataset(IterableDataset, DTUDatasetBase):
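    """Infinite stream for training: yields empty placeholder batches; the
    consumer is expected to sample rays from the tensors prepared in setup()."""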
def __init__(self, config, split):
self.setup(config, split)
def __iter__(self):
while True:
yield {}
@datasets.register('dtu')
class DTUDataModule(pl.LightningDataModule):
def __init__(self, config):
super().__init__()
self.config = config
def setup(self, stage=None):
if stage in [None, 'fit']:
self.train_dataset = DTUIterableDataset(self.config, 'train')
if stage in [None, 'fit', 'validate']:
self.val_dataset = DTUDataset(self.config, self.config.get('val_split', 'train'))
if stage in [None, 'test']:
self.test_dataset = DTUDataset(self.config, self.config.get('test_split', 'test'))
if stage in [None, 'predict']:
self.predict_dataset = DTUDataset(self.config, 'train')
def prepare_data(self):
pass
def general_loader(self, dataset, batch_size):
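        # no custom sampler; the DataLoader falls back to its default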
sampler = None
return DataLoader(
dataset,
num_workers=os.cpu_count(),
batch_size=batch_size,
pin_memory=True,
sampler=sampler
)
def train_dataloader(self):
return self.general_loader(self.train_dataset, batch_size=1)
def val_dataloader(self):
return self.general_loader(self.val_dataset, batch_size=1)
def test_dataloader(self):
return self.general_loader(self.test_dataset, batch_size=1)
def predict_dataloader(self):
return self.general_loader(self.predict_dataset, batch_size=1)
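

# Minimal usage sketch (illustrative only): assumes an OmegaConf-style config
# whose keys mirror the ones read in DTUDatasetBase.setup above; root_dir and
# cameras_file below are hypothetical placeholders.
if __name__ == '__main__':
    from omegaconf import OmegaConf
    config = OmegaConf.create({
        'root_dir': '/path/to/dtu/scan',    # hypothetical dataset path
        'cameras_file': 'cameras.npz',      # holds world_mat_i / scale_mat_i
        'img_downscale': 2,                 # alternatively set 'img_wh': [w, h]
        'apply_mask': True,
        'n_test_traj_steps': 60,
    })
    dm = DTUDataModule(config)
    dm.setup('fit')
    batch = next(iter(dm.train_dataloader()))  # {} placeholders from the iterable dataset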