import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from PIL import Image
from scipy.spatial.transform import Rotation
from rembg import remove, new_session
from einops import rearrange
from torchvision.transforms import ToTensor, Normalize, Compose
from pytorch_lightning import LightningDataModule

from sgm.data.colmap import read_cameras_binary, read_images_binary
from sgm.data.objaverse import video_collate_fn
def qvec2rotmat(qvec):
    """Convert a COLMAP quaternion (w, x, y, z) to a 3x3 rotation matrix."""
    return np.array(
        [
            [
                1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,
                2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
                2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2],
            ],
            [
                2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
                1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,
                2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1],
            ],
            [
                2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
                2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
                1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2,
            ],
        ]
    )
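

def _check_qvec_convention():
    # Illustrative sanity check (not called by the pipeline): COLMAP stores qvec
    # in (w, x, y, z) order, while scipy's Rotation.from_quat expects (x, y, z, w),
    # so the two conversions agree only after reordering the components.
    q = np.array([np.sqrt(0.5), 0.0, 0.0, np.sqrt(0.5)])  # 90 deg about z
    assert np.allclose(qvec2rotmat(q), Rotation.from_quat(q[[1, 2, 3, 0]]).as_matrix())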
def qt2c2w(q, t):
    # COLMAP stores poses as world-to-camera rotation q and translation t;
    # invert to camera-to-world: c2w = [R^T | -R^T t].
    # NOTE: remember to convert to the OpenGL coordinate system.
    rot = qvec2rotmat(q)
    c2w = np.eye(4)
    c2w[:3, :3] = np.transpose(rot)
    c2w[:3, 3] = -np.transpose(rot) @ t
    # Flip the y/z axes: COLMAP (right-down-forward) -> OpenGL (right-up-backward).
    c2w[..., 1:3] *= -1
    return c2w
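

def _check_qt2c2w():
    # Illustrative sanity check (not called by the pipeline): before the axis
    # flip, qt2c2w inverts COLMAP's world-to-camera transform [R | t].
    q = np.array([np.sqrt(0.5), 0.0, 0.0, np.sqrt(0.5)])
    t = np.array([1.0, 2.0, 3.0])
    w2c = np.eye(4)
    w2c[:3, :3] = qvec2rotmat(q)
    w2c[:3, 3] = t
    c2w = qt2c2w(q, t)
    c2w[..., 1:3] *= -1  # undo the COLMAP -> OpenGL axis flip
    assert np.allclose(c2w @ w2c, np.eye(4))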
def random_crop():
    # Placeholder: random cropping is currently handled inline in MVImageNet.__getitem__.
    pass
class MVImageNet(Dataset):
    def __init__(
        self,
        root_dir,
        split,
        transform,
        reso: int = 256,
        mask_type: str = "random",
        cond_aug_mean=-3.0,
        cond_aug_std=0.5,
        condition_on_elevation=False,
        fps_id=0.0,
        motion_bucket_id=300.0,
        num_frames: int = 24,
        use_mask: bool = True,
        load_pixelnerf: bool = False,
        scale_pose: bool = False,
        max_n_cond: int = 1,
        min_n_cond: int = 1,
        cond_on_multi: bool = False,
    ) -> None:
        super().__init__()
        self.root_dir = Path(root_dir)
        self.split = split
        # Scene ids are the second-level directories, i.e. "<category>/<scene>".
        avails = self.root_dir.glob("*/*")
        self.ids = list(
            map(
                lambda x: str(x.relative_to(self.root_dir)),
                filter(lambda x: x.is_dir(), avails),
            )
        )
        self.transform = transform
        self.reso = reso
        self.num_frames = num_frames
        self.cond_aug_mean = cond_aug_mean
        self.cond_aug_std = cond_aug_std
        self.condition_on_elevation = condition_on_elevation
        self.fps_id = fps_id
        self.motion_bucket_id = motion_bucket_id
        self.mask_type = mask_type
        self.use_mask = use_mask
        self.load_pixelnerf = load_pixelnerf
        self.scale_pose = scale_pose
        self.max_n_cond = max_n_cond
        self.min_n_cond = min_n_cond
        self.cond_on_multi = cond_on_multi

        if self.cond_on_multi:
            assert (
                self.min_n_cond == self.max_n_cond
            ), "cond_on_multi requires a fixed number of conditioning views"

        # rembg session for background removal when mask_type == "rembg"
        self.session = new_session()
    def __getitem__(self, index: int):
        # NOTE: mvimgnet image ids start at 1
        idx_list = np.arange(0, self.num_frames)
        this_image_dir = self.root_dir / self.ids[index] / "images"
        this_camera_dir = self.root_dir / self.ids[index] / "sparse/0"
        # Fall back to the first scene when COLMAP output is missing.
        if not this_camera_dir.exists():
            index = 0
            this_image_dir = self.root_dir / self.ids[index] / "images"
            this_camera_dir = self.root_dir / self.ids[index] / "sparse/0"

        this_images = read_images_binary(this_camera_dir / "images.bin")
        filenames = list(this_images.keys())
        if len(filenames) == 0:
            index = 0
            this_image_dir = self.root_dir / self.ids[index] / "images"
            this_camera_dir = self.root_dir / self.ids[index] / "sparse/0"
            this_images = read_images_binary(this_camera_dir / "images.bin")
            filenames = list(this_images.keys())

        # Keep only registered views whose image files exist, sorted by filename.
        filenames = list(
            filter(lambda x: (this_image_dir / this_images[x].name).exists(), filenames)
        )
        filenames = sorted(filenames, key=lambda x: this_images[x].name)

        # Pad short sequences by reflecting the tail until num_frames is reached.
        while 0 < len(filenames) < self.num_frames:
            num_missing = self.num_frames - len(filenames)
            filenames += list(reversed(filenames[-num_missing:]))
        if len(filenames) < self.num_frames:
            # Only reachable when the scene has no usable views at all.
            raise RuntimeError(f"no usable views for scene {self.ids[index]}")
        frames = []
        cameras = []
        downsampled_rgb = []
        for view_idx in idx_list:
            this_id = filenames[view_idx]
            frame = Image.open(this_image_dir / this_images[this_id].name)
            w, h = frame.size
            # Determine a square crop (left, top, right, bottom) of side image_size.
            if self.mask_type == "random":
                image_size = min(h, w)
                left = np.random.randint(0, w - image_size + 1)
                right = left + image_size
                top = np.random.randint(0, h - image_size + 1)
                bottom = top + image_size
            elif self.mask_type == "object":
                raise NotImplementedError("mask_type 'object' is not implemented")
            elif self.mask_type == "rembg":
                # Crop around the foreground, using a cached rembg alpha matte
                # when available.
                image_size = min(h, w)
                cached = this_image_dir / f"{this_images[this_id].name[:-4]}_rembg.png"
                if cached.exists():
                    try:
                        mask = np.asarray(Image.open(cached, formats=["PNG"]))[..., 3]
                    except Exception:
                        mask = remove(frame, session=self.session)
                        mask.save(cached)
                        mask = np.asarray(mask)[..., 3]
                else:
                    mask = remove(frame, session=self.session)
                    mask.save(cached)
                    mask = np.asarray(mask)[..., 3]
                # Center the crop on the mask centroid, clamped to the image bounds.
                y, x = np.array(mask.nonzero())  # (row, col) order
                bbox_cx = x.mean()
                bbox_cy = y.mean()
                if bbox_cy - image_size / 2 < 0:
                    top = 0
                elif bbox_cy + image_size / 2 > h:
                    top = h - image_size
                else:
                    top = int(bbox_cy - image_size / 2)
                if bbox_cx - image_size / 2 < 0:
                    left = 0
                elif bbox_cx + image_size / 2 > w:
                    left = w - image_size
                else:
                    left = int(bbox_cx - image_size / 2)
                bottom = top + image_size
                right = left + image_size
            else:
                raise ValueError(f"Unknown mask type: {self.mask_type}")

            frame = frame.crop((left, top, right, bottom))
            frame = frame.resize((self.reso, self.reso))
            frames.append(self.transform(frame))
            if self.load_pixelnerf:
                # Extrinsics: COLMAP pose -> OpenGL camera-to-world.
                extrinsics = this_images[this_id]
                c2w = qt2c2w(extrinsics.qvec, extrinsics.tvec)
                # Intrinsics (re-read per view; all views share a single camera).
                intrinsics = read_cameras_binary(this_camera_dir / "cameras.bin")
                assert len(intrinsics) == 1, "expected a single shared camera"
                intrinsics = intrinsics[1]
                # Assumes a 4-parameter camera model, e.g. SIMPLE_RADIAL (f, cx, cy, k).
                f, cx, cy, _ = intrinsics.params
                # Shift the principal point for the crop, then normalize by the
                # crop size so all intrinsics are relative values.
                cx -= left
                cy -= top
                f *= 1 / image_size
                cx *= 1 / image_size
                cy *= 1 / image_size
                intrinsics = np.array([[f, 0, cx], [0, f, cy], [0, 0, 1]])

                # Pack pose and intrinsics into a flat 25-dim vector (16 + 9).
                this_camera = np.zeros(25)
                this_camera[:16] = c2w.reshape(-1)
                this_camera[16:] = intrinsics.reshape(-1)

                cameras.append(this_camera)
                downsampled = frame.resize((self.reso // 8, self.reso // 8))
                downsampled_rgb.append((self.transform(downsampled) + 1.0) * 0.5)
        data = dict()
        # Log-normal noise level for the conditioning-frame augmentation.
        cond_aug = np.exp(np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean)
        frames = torch.stack(frames)
        cond = frames[0]

        data["frames"] = frames
        data["cond_frames_without_noise"] = cond
        data["cond_aug"] = torch.as_tensor([cond_aug] * self.num_frames)
        data["cond_frames"] = cond + cond_aug * torch.randn_like(cond)
        data["fps_id"] = torch.as_tensor([self.fps_id] * self.num_frames)
        data["motion_bucket_id"] = torch.as_tensor(
            [self.motion_bucket_id] * self.num_frames
        )
        data["num_video_frames"] = self.num_frames
        data["image_only_indicator"] = torch.as_tensor([0.0] * self.num_frames)
        if self.load_pixelnerf:
            # TODO: normalize camera poses
            data["pixelnerf_input"] = dict()
            data["pixelnerf_input"]["frames"] = frames
            data["pixelnerf_input"]["rgb"] = torch.stack(downsampled_rgb)
            cameras = torch.from_numpy(np.stack(cameras)).float()
            if self.scale_pose:
                # Recenter the camera centers and rescale so they fit inside a
                # sphere of radius 1.5.
                c2ws = cameras[..., :16].reshape(-1, 4, 4)
                center = c2ws[:, :3, 3].mean(0)
                radius = (c2ws[:, :3, 3] - center).norm(dim=-1).max()
                scale = 1.5 / radius
                c2ws[..., :3, 3] = (c2ws[..., :3, 3] - center) * scale
                cameras[..., :16] = c2ws.reshape(-1, 16)
            # Conditioning-view selection happens batch-wise in collate_fn so
            # that every sample in a batch uses the same number of source views.
            data["pixelnerf_input"]["cameras"] = cameras

        return data
    def __len__(self):
        return len(self.ids)
    def collate_fn(self, batch):
        # A hack to add the source indices here (rather than in __getitem__) so
        # the number of conditioning views stays consistent within a batch.
        if self.max_n_cond > 1:
            n_cond = np.random.randint(self.min_n_cond, self.max_n_cond + 1)
            for b in batch:
                source_index = [0]
                if n_cond > 1:
                    # Sample n_cond - 1 extra views (the original sampled
                    # max_n_cond - 1, which disagreed with n_cond).
                    source_index += np.random.choice(
                        np.arange(1, self.num_frames),
                        n_cond - 1,
                        replace=False,
                    ).tolist()
                b["pixelnerf_input"]["source_index"] = torch.as_tensor(source_index)
                b["pixelnerf_input"]["n_cond"] = n_cond
                b["pixelnerf_input"]["source_images"] = b["frames"][source_index]
                b["pixelnerf_input"]["source_cameras"] = b["pixelnerf_input"][
                    "cameras"
                ][source_index]

                if self.cond_on_multi:
                    b["cond_frames_without_noise"] = b["frames"][source_index]

        ret = video_collate_fn(batch)

        if self.cond_on_multi:
            # Flatten the conditioning views so downstream encoders see a plain
            # batch of images.
            ret["cond_frames_without_noise"] = rearrange(
                ret["cond_frames_without_noise"], "b t ... -> (b t) ..."
            )

        return ret
class MVImageNetFixedCond(MVImageNet):
    # Placeholder subclass; currently identical to MVImageNet.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
class MVImageNetDataset(LightningDataModule):
    def __init__(
        self,
        root_dir,
        batch_size=2,
        shuffle=True,
        num_workers=10,
        prefetch_factor=2,
        **kwargs,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.shuffle = shuffle
        # Map PIL images to tensors in [-1, 1].
        self.transform = Compose(
            [
                ToTensor(),
                Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

        self.train_dataset = MVImageNet(
            root_dir=root_dir,
            split="train",
            transform=self.transform,
            **kwargs,
        )

        self.test_dataset = MVImageNet(
            root_dir=root_dir,
            split="test",
            transform=self.transform,
            **kwargs,
        )
    def train_dataloader(self):
        def worker_init_fn(worker_id):
            # Re-seed numpy per worker; otherwise all workers inherit the same
            # RNG state and draw identical random crops.
            np.random.seed(np.random.get_state()[1][0] + worker_id)

        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            prefetch_factor=self.prefetch_factor if self.num_workers > 0 else None,
            worker_init_fn=worker_init_fn,
            collate_fn=self.train_dataset.collate_fn,
        )
    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            prefetch_factor=self.prefetch_factor if self.num_workers > 0 else None,
            collate_fn=self.test_dataset.collate_fn,
        )

    def val_dataloader(self):
        # NOTE: validation uses the plain video_collate_fn, so no multi-view
        # conditioning fields are added here.
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            prefetch_factor=self.prefetch_factor if self.num_workers > 0 else None,
            collate_fn=video_collate_fn,
        )
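

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: "./data/mvimgnet" is a placeholder path
    # and assumes MVImageNet scenes laid out as <root>/<category>/<scene> with
    # COLMAP output under sparse/0; num_frames and batch_size are arbitrary.
    datamodule = MVImageNetDataset(
        root_dir="./data/mvimgnet",
        batch_size=2,
        num_workers=0,
        num_frames=8,
        load_pixelnerf=True,
    )
    batch = next(iter(datamodule.train_dataloader()))
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            print(k, tuple(v.shape))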