heheyas
init
cfb7702
raw
history blame
10.4 kB
import os
import json
import math
import numpy as np
from PIL import Image
import cv2
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, IterableDataset
import torchvision.transforms.functional as TF
import pytorch_lightning as pl
import datasets
from models.ray_utils import get_ortho_ray_directions_origins, get_ortho_rays, get_ray_directions
from utils.misc import get_rank
from glob import glob
import PIL.Image
def camNormal2worldNormal(rot_c2w, camNormal):
H,W,_ = camNormal.shape
normal_img = np.matmul(rot_c2w[None, :, :], camNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3])
return normal_img
def worldNormal2camNormal(rot_w2c, worldNormal):
H,W,_ = worldNormal.shape
normal_img = np.matmul(rot_w2c[None, :, :], worldNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3])
return normal_img
def trans_normal(normal, RT_w2c, RT_w2c_target):
normal_world = camNormal2worldNormal(np.linalg.inv(RT_w2c[:3,:3]), normal)
normal_target_cam = worldNormal2camNormal(RT_w2c_target[:3,:3], normal_world)
return normal_target_cam
def img2normal(img):
return (img/255.)*2-1
def normal2img(normal):
return np.uint8((normal*0.5+0.5)*255)
def norm_normalize(normal, dim=-1):
normal = normal/(np.linalg.norm(normal, axis=dim, keepdims=True)+1e-6)
return normal
def RT_opengl2opencv(RT):
# Build the coordinate transform matrix from world to computer vision camera
# R_world2cv = R_bcam2cv@R_world2bcam
# T_world2cv = R_bcam2cv@T_world2bcam
R = RT[:3, :3]
t = RT[:3, 3]
R_bcam2cv = np.asarray([[1, 0, 0], [0, -1, 0], [0, 0, -1]], np.float32)
R_world2cv = R_bcam2cv @ R
t_world2cv = R_bcam2cv @ t
RT = np.concatenate([R_world2cv,t_world2cv[:,None]],1)
return RT
def normal_opengl2opencv(normal):
H,W,C = np.shape(normal)
# normal_img = np.reshape(normal, (H*W,C))
R_bcam2cv = np.array([1, -1, -1], np.float32)
normal_cv = normal * R_bcam2cv[None, None, :]
print(np.shape(normal_cv))
return normal_cv
def inv_RT(RT):
RT_h = np.concatenate([RT, np.array([[0,0,0,1]])], axis=0)
RT_inv = np.linalg.inv(RT_h)
return RT_inv[:3, :]
def load_a_prediction(root_dir, test_object, imSize, view_types, load_color=False, cam_pose_dir=None,
normal_system='front', erode_mask=True, camera_type='ortho', cam_params=None):
all_images = []
all_normals = []
all_normals_world = []
all_masks = []
all_color_masks = []
all_poses = []
all_w2cs = []
directions = []
ray_origins = []
RT_front = np.loadtxt(glob(os.path.join(cam_pose_dir, '*_%s_RT.txt'%( 'front')))[0]) # world2cam matrix
RT_front_cv = RT_opengl2opencv(RT_front) # convert normal from opengl to opencv
for idx, view in enumerate(view_types):
print(os.path.join(root_dir,test_object))
normal_filepath = os.path.join(root_dir, test_object, 'normals_000_%s.png'%( view))
# Load key frame
if load_color: # use bgr
image =np.array(PIL.Image.open(normal_filepath.replace("normals", "rgb")).resize(imSize))[:, :, :3]
normal = np.array(PIL.Image.open(normal_filepath).resize(imSize))
mask = normal[:, :, 3]
normal = normal[:, :, :3]
color_mask = np.array(PIL.Image.open(os.path.join(root_dir,test_object, 'masked_colors/rgb_000_%s.png'%( view))).resize(imSize))[:, :, 3]
invalid_color_mask = color_mask < 255*0.5
threshold = np.ones_like(image[:, :, 0]) * 250
invalid_white_mask = (image[:, :, 0] > threshold) & (image[:, :, 1] > threshold) & (image[:, :, 2] > threshold)
invalid_color_mask_final = invalid_color_mask & invalid_white_mask
color_mask = (1 - invalid_color_mask_final) > 0
# if erode_mask:
# kernel = np.ones((3, 3), np.uint8)
# mask = cv2.erode(mask, kernel, iterations=1)
RT = np.loadtxt(os.path.join(cam_pose_dir, '000_%s_RT.txt'%( view))) # world2cam matrix
normal = img2normal(normal)
normal[mask==0] = [0,0,0]
mask = mask> (0.5*255)
if load_color:
all_images.append(image)
all_masks.append(mask)
all_color_masks.append(color_mask)
RT_cv = RT_opengl2opencv(RT) # convert normal from opengl to opencv
all_poses.append(inv_RT(RT_cv)) # cam2world
all_w2cs.append(RT_cv)
# whether to
normal_cam_cv = normal_opengl2opencv(normal)
if normal_system == 'front':
print("the loaded normals are defined in the system of front view")
normal_world = camNormal2worldNormal(inv_RT(RT_front_cv)[:3, :3], normal_cam_cv)
elif normal_system == 'self':
print("the loaded normals are in their independent camera systems")
normal_world = camNormal2worldNormal(inv_RT(RT_cv)[:3, :3], normal_cam_cv)
all_normals.append(normal_cam_cv)
all_normals_world.append(normal_world)
if camera_type == 'ortho':
origins, dirs = get_ortho_ray_directions_origins(W=imSize[0], H=imSize[1])
elif camera_type == 'pinhole':
dirs = get_ray_directions(W=imSize[0], H=imSize[1],
fx=cam_params[0], fy=cam_params[1], cx=cam_params[2], cy=cam_params[3])
origins = dirs # occupy a position
else:
raise Exception("not support camera type")
ray_origins.append(origins)
directions.append(dirs)
if not load_color:
all_images = [normal2img(x) for x in all_normals_world]
return np.stack(all_images), np.stack(all_masks), np.stack(all_normals), \
np.stack(all_normals_world), np.stack(all_poses), np.stack(all_w2cs), np.stack(ray_origins), np.stack(directions), np.stack(all_color_masks)
class OrthoDatasetBase():
def setup(self, config, split):
self.config = config
self.split = split
self.rank = get_rank()
self.data_dir = self.config.root_dir
self.object_name = self.config.scene
self.scene = self.config.scene
self.imSize = self.config.imSize
self.load_color = True
self.img_wh = [self.imSize[0], self.imSize[1]]
self.w = self.img_wh[0]
self.h = self.img_wh[1]
self.camera_type = self.config.camera_type
self.camera_params = self.config.camera_params # [fx, fy, cx, cy]
self.view_types = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
self.view_weights = torch.from_numpy(np.array(self.config.view_weights)).float().to(self.rank).view(-1)
self.view_weights = self.view_weights.view(-1,1,1).repeat(1, self.h, self.w)
if self.config.cam_pose_dir is None:
self.cam_pose_dir = "./datasets/fixed_poses"
else:
self.cam_pose_dir = self.config.cam_pose_dir
self.images_np, self.masks_np, self.normals_cam_np, self.normals_world_np, \
self.pose_all_np, self.w2c_all_np, self.origins_np, self.directions_np, self.rgb_masks_np = load_a_prediction(
self.data_dir, self.object_name, self.imSize, self.view_types,
self.load_color, self.cam_pose_dir, normal_system='front',
camera_type=self.camera_type, cam_params=self.camera_params)
self.has_mask = True
self.apply_mask = self.config.apply_mask
self.all_c2w = torch.from_numpy(self.pose_all_np)
self.all_images = torch.from_numpy(self.images_np) / 255.
self.all_fg_masks = torch.from_numpy(self.masks_np)
self.all_rgb_masks = torch.from_numpy(self.rgb_masks_np)
self.all_normals_world = torch.from_numpy(self.normals_world_np)
self.origins = torch.from_numpy(self.origins_np)
self.directions = torch.from_numpy(self.directions_np)
self.directions = self.directions.float().to(self.rank)
self.origins = self.origins.float().to(self.rank)
self.all_rgb_masks = self.all_rgb_masks.float().to(self.rank)
self.all_c2w, self.all_images, self.all_fg_masks, self.all_normals_world = \
self.all_c2w.float().to(self.rank), \
self.all_images.float().to(self.rank), \
self.all_fg_masks.float().to(self.rank), \
self.all_normals_world.float().to(self.rank)
class OrthoDataset(Dataset, OrthoDatasetBase):
def __init__(self, config, split):
self.setup(config, split)
def __len__(self):
return len(self.all_images)
def __getitem__(self, index):
return {
'index': index
}
class OrthoIterableDataset(IterableDataset, OrthoDatasetBase):
def __init__(self, config, split):
self.setup(config, split)
def __iter__(self):
while True:
yield {}
@datasets.register('ortho')
class OrthoDataModule(pl.LightningDataModule):
def __init__(self, config):
super().__init__()
self.config = config
def setup(self, stage=None):
if stage in [None, 'fit']:
self.train_dataset = OrthoIterableDataset(self.config, 'train')
if stage in [None, 'fit', 'validate']:
self.val_dataset = OrthoDataset(self.config, self.config.get('val_split', 'train'))
if stage in [None, 'test']:
self.test_dataset = OrthoDataset(self.config, self.config.get('test_split', 'test'))
if stage in [None, 'predict']:
self.predict_dataset = OrthoDataset(self.config, 'train')
def prepare_data(self):
pass
def general_loader(self, dataset, batch_size):
sampler = None
return DataLoader(
dataset,
num_workers=os.cpu_count(),
batch_size=batch_size,
pin_memory=True,
sampler=sampler
)
def train_dataloader(self):
return self.general_loader(self.train_dataset, batch_size=1)
def val_dataloader(self):
return self.general_loader(self.val_dataset, batch_size=1)
def test_dataloader(self):
return self.general_loader(self.test_dataset, batch_size=1)
def predict_dataloader(self):
return self.general_loader(self.predict_dataset, batch_size=1)