# PSHuman/mvdiffusion/data/dreamdata.py
import numpy as np
import torch
from torch.utils.data import Dataset
import json
from typing import Tuple, Optional, Any
import cv2
import random
import os
import math
from PIL import Image, ImageOps
from .normal_utils import worldNormal2camNormal, img2normal, norm_normalize
from icecream import ic
def shift_list(lst, n):
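    # Rotates the list to the right by n positions (illustrative example):
    # shift_list([0, 1, 2, 3], 1) -> [3, 0, 1, 2]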
length = len(lst)
n = n % length # Ensure n is within the range of the list length
return lst[-n:] + lst[:-n]
class ObjaverseDataset(Dataset):
def __init__(self,
root_dir: str,
azi_interval: float,
random_views: int,
predict_relative_views: list,
bg_color: Any,
object_list: str,
prompt_embeds_path: str,
img_wh: Tuple[int, int],
validation: bool = False,
num_validation_samples: int = 64,
num_samples: Optional[int] = None,
invalid_list: Optional[str] = None,
                 trans_norm_system: bool = True,  # if True, transform all normal maps into the camera system of the front view
# augment_data: bool = False,
side_views_rate: float = 0.,
read_normal: bool = True,
read_color: bool = False,
read_depth: bool = False,
mix_color_normal: bool = False,
random_view_and_domain: bool = False,
load_cache: bool = False,
exten: str = '.png',
elevation_list: Optional[str] = None,
with_smpl: Optional[bool] = False,
) -> None:
"""Create a dataset from a folder of images.
If you pass in a root directory it will be searched for images
ending in ext (ext can be a list)
"""
self.root_dir = root_dir
self.fixed_views = int(360 // azi_interval)
self.bg_color = bg_color
self.validation = validation
self.num_samples = num_samples
self.trans_norm_system = trans_norm_system
# self.augment_data = augment_data
self.img_wh = img_wh
self.read_normal = read_normal
self.read_color = read_color
self.read_depth = read_depth
        self.mix_color_normal = mix_color_normal  # load a mix of color and normal maps
self.random_view_and_domain = random_view_and_domain # load normal or rgb of a single view
self.random_views = random_views
self.load_cache = load_cache
self.total_views = int(self.fixed_views * (self.random_views + 1))
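        # e.g. with the __main__ config below (azi_interval=45, random_views=0):
        # fixed_views = 8 and total_views = 8, i.e. view ids 000-007, one per 45 deg of azimuth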
self.predict_relative_views = predict_relative_views
self.pred_view_nums = len(self.predict_relative_views)
self.exten = exten
self.side_views_rate = side_views_rate
self.with_smpl = with_smpl
if self.with_smpl:
self.smpl_image_path = 'smpl_image'
self.smpl_normal_path = 'smpl_normal'
ic(self.total_views)
ic(self.fixed_views)
ic(self.predict_relative_views)
ic(self.with_smpl)
self.objects = []
if object_list is not None:
for dataset_list in object_list:
with open(dataset_list, 'r') as f:
objects = json.load(f)
self.objects.extend(objects)
else:
self.objects = os.listdir(self.root_dir)
# load fixed camera poses
self.trans_cv2gl_mat = np.linalg.inv(np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]))
self.fix_cam_poses = []
camera_path = os.path.join(self.root_dir, self.objects[0], 'camera')
for vid in range(0, self.total_views, self.random_views+1):
cam_info = np.load(f'{camera_path}/{vid:03d}.npy', allow_pickle=True).item()
            assert cam_info['camera'] == 'ortho', 'Only orthographic cameras are supported!'
self.fix_cam_poses.append(cam_info['extrinsic'])
random.shuffle(self.objects)
if elevation_list:
with open(elevation_list, 'r') as f:
ele_list = [o.strip() for o in f.readlines()]
self.objects = set(ele_list) & set(self.objects)
self.all_objects = set(self.objects)
self.all_objects = list(self.all_objects)
self.validation = validation
if not validation:
self.all_objects = self.all_objects[:-num_validation_samples]
# print('Warning: you are fitting in small-scale dataset')
# self.all_objects = self.all_objects
else:
self.all_objects = self.all_objects[-num_validation_samples:]
if num_samples is not None:
self.all_objects = self.all_objects[:num_samples]
ic(len(self.all_objects))
print(f"loaded {len(self.all_objects)} in the dataset")
normal_prompt_embedding = torch.load(f'{prompt_embeds_path}/normal_embeds.pt')
color_prompt_embedding = torch.load(f'{prompt_embeds_path}/clr_embeds.pt')
if len(self.predict_relative_views) == 6:
self.normal_prompt_embedding = normal_prompt_embedding
self.color_prompt_embedding = color_prompt_embedding
elif len(self.predict_relative_views) == 4:
self.normal_prompt_embedding = torch.stack([normal_prompt_embedding[0], normal_prompt_embedding[2], normal_prompt_embedding[3], normal_prompt_embedding[4], normal_prompt_embedding[6]] , 0)
self.color_prompt_embedding = torch.stack([color_prompt_embedding[0], color_prompt_embedding[2], color_prompt_embedding[3], color_prompt_embedding[4], color_prompt_embedding[6]] , 0)
# flip back and left views
if len(self.predict_relative_views) == 6:
self.flip_views = [3, 4]
elif len(self.predict_relative_views) == 4:
self.flip_views = [2, 3]
# self.backup_data = self.__getitem_norm__(0, 'Thuman2.0/0340')
self.backup_data = self.__getitem_norm__(0)
def trans_cv2gl(self, rt):
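        # trans_cv2gl_mat is diag(1, -1, -1) (its own inverse), so this negates the
        # Y and Z rows of R and the Y/Z components of t, i.e. the usual
        # OpenCV (y down, z forward) to OpenGL (y up, z backward) camera-convention flip.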
r, t = rt[:3, :3], rt[:3, -1]
r = np.matmul(self.trans_cv2gl_mat, r)
t = np.matmul(self.trans_cv2gl_mat, t)
return np.concatenate([r, t[:, None]], axis=-1)
def cartesian_to_spherical(self, xyz):
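        # e.g. cartesian_to_spherical(np.array([[0., 0., 1.]])) gives theta=0, azimuth=0, z=1:
        # theta is the polar angle measured from the +Z axis, azimuth is measured from the +X axis.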
ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
xy = xyz[:,0]**2 + xyz[:,1]**2
z = np.sqrt(xy + xyz[:,2]**2)
theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down
#ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
azimuth = np.arctan2(xyz[:,1], xyz[:,0])
return np.array([theta, azimuth, z])
def get_T(self, target_RT, cond_RT):
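        # Returns the (polar, azimuth) offset of the target camera relative to the
        # condition camera (camera centers recovered via -R^T @ T); d_azimuth is
        # wrapped to [0, 2*pi), e.g. identical poses give (0, 0).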
R, T = target_RT[:3, :3], target_RT[:3, -1]
T_target = -R.T @ T # change to cam2world
R, T = cond_RT[:3, :3], cond_RT[:3, -1]
T_cond = -R.T @ T
theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :])
theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])
d_theta = theta_target - theta_cond
d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
d_z = z_target - z_cond
# d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()])
return d_theta, d_azimuth
def get_bg_color(self):
if self.bg_color == 'white':
bg_color = np.array([1., 1., 1.], dtype=np.float32)
elif self.bg_color == 'black':
bg_color = np.array([0., 0., 0.], dtype=np.float32)
elif self.bg_color == 'gray':
bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32)
elif self.bg_color == 'random':
bg_color = np.random.rand(3)
elif self.bg_color == 'three_choices':
white = np.array([1., 1., 1.], dtype=np.float32)
black = np.array([0., 0., 0.], dtype=np.float32)
gray = np.array([0.5, 0.5, 0.5], dtype=np.float32)
bg_color = random.choice([white, black, gray])
elif isinstance(self.bg_color, float):
bg_color = np.array([self.bg_color] * 3, dtype=np.float32)
else:
raise NotImplementedError
return bg_color
def crop_image(self, top_left, img):
size = max(self.img_wh)
tar_size = size - top_left * 2
alpha_np = np.asarray(img)[:, :, 3]
        coords = np.argwhere(alpha_np > 0.5)  # (row, col) indices of foreground pixels
        y_min, x_min = coords.min(axis=0)
        y_max, x_max = coords.max(axis=0)
        # PIL crop box is (left, upper, right, lower), i.e. (x_min, y_min, x_max, y_max)
        img = img.crop((x_min, y_min, x_max, y_max)).resize((tar_size, tar_size))
img = ImageOps.expand(img, border=(top_left, top_left, top_left, top_left), fill=0)
return img
def load_cropped_img(self, img_path, bg_color, top_left, return_type='np'):
rgba = Image.open(img_path)
rgba = self.crop_image(top_left, rgba)
rgba = np.array(rgba)
rgba = rgba.astype(np.float32) / 255. # [0, 1]
img, alpha = rgba[..., :3], rgba[..., 3:4]
img = img[...,:3] * alpha + bg_color * (1 - alpha)
if return_type == "np":
pass
elif return_type == "pt":
img = torch.from_numpy(img)
alpha = torch.from_numpy(alpha)
else:
raise NotImplementedError
return img, alpha
def load_image(self, img_path, bg_color, alpha=None, return_type='np'):
        # not using cv2, as it may load images in uint16 format
        # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # [0, 255]
        # img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_CUBIC)
        # PIL always returns uint8
rgba = np.array(Image.open(img_path).resize(self.img_wh))
rgba = rgba.astype(np.float32) / 255. # [0, 1]
img = rgba[..., :3]
if alpha is None:
assert rgba.shape[-1] == 4
alpha = rgba[..., 3:4]
assert alpha.sum() > 1e-8, 'w/o foreground'
img = img[...,:3] * alpha + bg_color * (1 - alpha)
if return_type == "np":
pass
elif return_type == "pt":
img = torch.from_numpy(img)
alpha = torch.from_numpy(alpha)
else:
raise NotImplementedError
return img, alpha
def load_normal(self, img_path, bg_color, alpha, RT_w2c_cond=None, return_type='np'):
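        # Decodes a world-space normal map (img2normal), rotates it into the condition
        # camera frame with RT_w2c_cond, renormalizes, flips Y/Z into the OpenGL
        # convention, maps [-1, 1] back to [0, 1], and composites over bg_color with alpha.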
normal_np = np.array(Image.open(img_path).resize(self.img_wh))[:, :, :3]
assert np.var(normal_np) > 1e-8, 'pure normal'
normal_cv = img2normal(normal_np)
normal_relative_cv = worldNormal2camNormal(RT_w2c_cond[:3, :3], normal_cv)
normal_relative_cv = norm_normalize(normal_relative_cv)
        # note: normal_relative_gl aliases normal_relative_cv, so the in-place sign flip
        # below (OpenCV -> OpenGL convention) also applies to the array used for img
        normal_relative_gl = normal_relative_cv
        normal_relative_gl[..., 1:] = -normal_relative_gl[..., 1:]
        img = (normal_relative_cv * 0.5 + 0.5).astype(np.float32)  # [-1, 1] -> [0, 1]
if alpha.shape[-1] != 1:
alpha = alpha[:, :, None]
img = img[...,:3] * alpha + bg_color * (1 - alpha)
if return_type == "np":
pass
elif return_type == "pt":
img = torch.from_numpy(img)
else:
raise NotImplementedError
return img
def load_halfbody_normal(self, img_path, bg_color, alpha, RT_w2c_cond=None, return_type='np'):
normal_np = np.array(Image.open(img_path).resize(self.img_wh).crop((256, 0, 512, 256)).resize(self.img_wh))[:, :, :3]
assert np.var(normal_np) > 1e-8, 'pure normal'
normal_cv = img2normal(normal_np)
normal_relative_cv = worldNormal2camNormal(RT_w2c_cond[:3, :3], normal_cv)
normal_relative_cv = norm_normalize(normal_relative_cv)
# normal_relative_gl = normal_relative_cv[..., [ 0, 2, 1]]
# normal_relative_gl[..., 2] = -normal_relative_gl[..., 2]
normal_relative_gl = normal_relative_cv
normal_relative_gl[..., 1:] = -normal_relative_gl[..., 1:]
img = (normal_relative_cv*0.5 + 0.5).astype(np.float32) # [0, 1]
if alpha.shape[-1] != 1:
alpha = alpha[:, :, None]
img = img[...,:3] * alpha + bg_color * (1 - alpha)
if return_type == "np":
pass
elif return_type == "pt":
img = torch.from_numpy(img)
else:
raise NotImplementedError
return img
def __len__(self):
return len(self.all_objects)
def load_halfbody_image(self, img_path, bg_color, alpha=None, return_type='np'):
rgba = np.array(Image.open(img_path).resize(self.img_wh).crop((256, 0, 512, 256)).resize(self.img_wh))
rgba = rgba.astype(np.float32) / 255. # [0, 1]
img = rgba[..., :3]
if alpha is None:
assert rgba.shape[-1] == 4
alpha = rgba[..., 3:4]
assert alpha.sum() > 1e-8, 'w/o foreground'
img = img[...,:3] * alpha + bg_color * (1 - alpha)
if return_type == "np":
pass
elif return_type == "pt":
img = torch.from_numpy(img)
alpha = torch.from_numpy(alpha)
else:
raise NotImplementedError
return img, alpha
def __getitem_norm__(self, index, debug_object=None):
# get the bg color
bg_color = self.get_bg_color()
if debug_object is not None:
object_name = debug_object
else:
object_name = self.all_objects[index % len(self.all_objects)]
face_info = np.load(f'{self.root_dir}/{object_name}/face_info.npy', allow_pickle=True).item()
# front_fixed_idx = face_info['top3_vid'][0] // (self.random_views+1)
if self.side_views_rate > 0 and random.random() < self.side_views_rate:
front_fixed_idx = random.choice(face_info['top3_vid'])
else:
front_fixed_idx = face_info['top3_vid'][0]
with_face_idx = list(face_info.keys())
with_face_idx.remove('top3_vid')
        assert front_fixed_idx in with_face_idx, 'no face detected for the selected front view'
if self.validation:
cond_ele0_idx = front_fixed_idx
cond_random_idx = 0
else:
            if object_name[:9] == 'realistic':  # this dataset has random poses
cond_ele0_idx = random.choice(range(self.fixed_views))
cond_random_idx = random.choice(range(self.random_views+1))
else:
cond_vid = front_fixed_idx
cond_ele0_idx = cond_vid // (self.random_views + 1)
cond_ele0_vid = cond_ele0_idx * (self.random_views + 1)
cond_random_idx = 0
# condition info
cond_ele0_vid = cond_ele0_idx * (self.random_views + 1)
cond_vid = cond_ele0_vid + cond_random_idx
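        # with random_views=0 (as in the __main__ config below), cond_random_idx is always 0,
        # so cond_vid == cond_ele0_vid == cond_ele0_idx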
cond_ele0_w2c = self.fix_cam_poses[cond_ele0_idx]
img_tensors_in = [
self.load_image(f"{self.root_dir}/{object_name}/image/{cond_vid:03d}{self.exten}", bg_color, return_type='pt')[0].permute(2, 0, 1)
] * self.pred_view_nums + [
self.load_halfbody_image(f"{self.root_dir}/{object_name}/image/{cond_vid:03d}{self.exten}", bg_color, return_type='pt')[0].permute(2, 0, 1)
]
# output info
pred_vids = [(cond_ele0_vid + i * (self.random_views+1)) % self.total_views for i in self.predict_relative_views]
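        # e.g. with fixed_views=8, random_views=0 and predict_relative_views=[0, 2, 4, 6],
        # a condition view at cond_ele0_vid=2 gives pred_vids = [2, 4, 6, 0],
        # i.e. the condition azimuth plus 0/90/180/270 degrees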
# pred_w2cs = [self.fix_cam_poses[(cond_ele0_idx + i) % self.fixed_views] for i in self.predict_relative_views]
img_tensors_out = []
normal_tensors_out = []
smpl_tensors_in = []
for i, vid in enumerate(pred_vids):
# output image
img_tensor, alpha_ = self.load_image(f"{self.root_dir}/{object_name}/image/{vid:03d}{self.exten}", bg_color, return_type='pt')
img_tensor = img_tensor.permute(2, 0, 1) # (3, H, W)
if i in self.flip_views: img_tensor = torch.flip(img_tensor, [2])
img_tensors_out.append(img_tensor)
# output normal
normal_tensor = self.load_normal(f"{self.root_dir}/{object_name}/normal/{vid:03d}{self.exten}", bg_color, alpha_.numpy(), RT_w2c_cond=cond_ele0_w2c[:3, :], return_type="pt").permute(2, 0, 1)
if i in self.flip_views: normal_tensor = torch.flip(normal_tensor, [2])
normal_tensors_out.append(normal_tensor)
# input smpl image
if self.with_smpl:
smpl_image_tensor, smpl_alpha_ = self.load_image(f"{self.root_dir}/{object_name}/{self.smpl_image_path}/{vid:03d}{self.exten}", bg_color, return_type='pt')
smpl_image_tensor = smpl_image_tensor.permute(2, 0, 1) # (3, H, W)
if i in self.flip_views: smpl_image_tensor = torch.flip(smpl_image_tensor, [2])
smpl_tensors_in.append(smpl_image_tensor)
# faces
if i == 0:
face_clr_out, face_alpha_out = self.load_halfbody_image(f"{self.root_dir}/{object_name}/image/{vid:03d}{self.exten}", bg_color, return_type='pt')
face_clr_out = face_clr_out.permute(2, 0, 1)
face_nrm_out = self.load_halfbody_normal(f"{self.root_dir}/{object_name}/normal/{vid:03d}{self.exten}", bg_color, face_alpha_out.numpy(), RT_w2c_cond=cond_ele0_w2c[:3, :], return_type="pt").permute(2, 0, 1)
if self.with_smpl:
face_smpl_in = self.load_halfbody_image(f"{self.root_dir}/{object_name}/{self.smpl_image_path}/{vid:03d}{self.exten}", bg_color, return_type='pt')[0].permute(2, 0, 1)
img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
img_tensors_out.append(face_clr_out)
img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
normal_tensors_out.append(face_nrm_out)
normal_tensors_out = torch.stack(normal_tensors_out, dim=0).float() # (Nv, 3, H, W)
if self.with_smpl:
smpl_tensors_in = smpl_tensors_in + [face_smpl_in]
smpl_tensors_in = torch.stack(smpl_tensors_in, dim=0).float() # (Nv, 3, H, W)
item = {
'id': object_name.replace('/', '_'),
'vid':cond_vid,
'imgs_in': img_tensors_in,
'imgs_out': img_tensors_out,
'normals_out': normal_tensors_out,
'normal_prompt_embeddings': self.normal_prompt_embedding,
'color_prompt_embeddings': self.color_prompt_embedding,
}
if self.with_smpl:
item.update({'smpl_imgs_in': smpl_tensors_in})
return item
def __getitem__(self, index):
try:
data = self.__getitem_norm__(index)
return data
        except Exception:
            print("load error:", self.all_objects[index % len(self.all_objects)])
            return self.backup_data
def draw_kps(image, kps):
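    # Draws a fixed 128x128 green box centered on the nose keypoint (kps[2]);
    # only used by the commented-out debug code in __main__ below.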
nose_pos = kps[2].astype(np.int32)
top_left = nose_pos - 64
bottom_right = nose_pos + 64
image_cv = image.copy()
img = cv2.rectangle(image_cv, tuple(top_left), tuple(bottom_right), (0, 255, 0), 2)
return img
if __name__ == "__main__":
# pass
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from PIL import ImageDraw, ImageFont
def draw_text(img, text, pos, color=(128, 128, 128)):
draw = ImageDraw.Draw(img)
# font = ImageFont.truetype(size= size)
font = ImageFont.load_default()
font = font.font_variant(size=10)
draw.text(pos, text, color, font=font)
return img
random.seed(11)
train_params = dict(
root_dir='/aifs4su/mmcode/lipeng/human_8view_with_smplx/',
azi_interval=45.,
random_views=0,
predict_relative_views=[0,2,4,6],
bg_color='white',
object_list=['../../data_lists/human_only_scan_with_smplx.json'],
img_wh=(768, 768),
validation=False,
num_validation_samples=10,
read_normal=True,
read_color=True,
read_depth=False,
# mix_color_normal= True,
random_view_and_domain=False,
load_cache=False,
exten='.png',
prompt_embeds_path='fixed_prompt_embeds_7view',
side_views_rate=0.1,
with_smpl=True
)
train_dataset = ObjaverseDataset(**train_params)
data_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)
if False:
case = 'CustomHumans/0593_00083_06_00101'
batch = train_dataset.__getitem_norm__(0, case)
imgs = []
obj_name = batch['id'][:8]
imgs_in = batch['imgs_in']
imgs_out = batch['imgs_out']
normal_out = batch['normals_out']
imgs_vis = torch.cat([imgs_in[0:1], imgs_in[-1:], imgs_out, normal_out], 0)
img_vis = make_grid(imgs_vis, nrow=16).permute(1, 2,0)
img_vis = (img_vis.numpy() * 255).astype(np.uint8)
img_vis = Image.fromarray(img_vis)
img_vis = draw_text(img_vis, obj_name, (5, 1))
img_vis = torch.from_numpy(np.array(img_vis)).permute(2, 0, 1) / 255.
imgs.append(img_vis)
imgs = torch.stack(imgs, dim=0)
img_grid = make_grid(imgs, nrow=4, padding=0)
img_grid = img_grid.permute(1, 2, 0).numpy()
img_grid = (img_grid * 255).astype(np.uint8)
img_grid = Image.fromarray(img_grid)
img_grid.save(f'../../debug/{case.replace("/", "_")}.png')
else:
imgs = []
i = 0
for batch in data_loader:
# print(i)
if i < 4:
i += 1
obj_name = batch['id'][0][:8]
imgs_in = batch['imgs_in'].squeeze(0)
smpl_in = batch['smpl_imgs_in'].squeeze(0)
imgs_out = batch['imgs_out'].squeeze(0)
normal_out = batch['normals_out'].squeeze(0)
imgs_vis = torch.cat([imgs_in[0:1], imgs_in[-1:], smpl_in, imgs_out, normal_out], 0)
img_vis = make_grid(imgs_vis, nrow=12).permute(1, 2,0)
img_vis = (img_vis.numpy() * 255).astype(np.uint8)
print(img_vis.shape)
# import pdb;pdb.set_trace()
# nose_kps = batch['face_kps'][0].numpy()
# print(nose_kps)
# img_vis = draw_kps(img_vis, nose_kps)
img_vis = Image.fromarray(img_vis)
img_vis = draw_text(img_vis, obj_name, (5, 1))
img_vis = torch.from_numpy(np.array(img_vis)).permute(2, 0, 1) / 255.
imgs.append(img_vis)
else:
break
imgs = torch.stack(imgs, dim=0)
img_grid = make_grid(imgs, nrow=1, padding=0)
img_grid = img_grid.permute(1, 2, 0).numpy()
img_grid = (img_grid * 255).astype(np.uint8)
img_grid = Image.fromarray(img_grid)
img_grid.save('../../debug/noele_imgs_out_10.png')