# code taken and modified from https://github.com/amyxlase/relpose-plus-plus/blob/b33f7d5000cf2430bfcda6466c8e89bc2dcde43f/relpose/dataset/co3d_v2.py#L346
import os.path as osp
import random
import numpy as np
import torch
import pytorch_lightning as pl
from PIL import Image, ImageFile
import json
import gzip
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.implicitron.dataset.utils import adjust_camera_to_bbox_crop_, adjust_camera_to_image_scale_
from pytorch3d.transforms import Rotate, Translate

CO3D_DIR = "data/training/"
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Added: normalize camera poses
def intersect_skew_line_groups(p, r, mask):
    # p, r both of shape (B, N, n_intersected_lines, 3)
    # mask of shape (B, N, n_intersected_lines)
    p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask)
    _, p_line_intersect = _point_line_distance(
        p, r, p_intersect[..., None, :].expand_as(p)
    )
    intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum(
        dim=-1
    )
    return p_intersect, p_line_intersect, intersect_dist_squared, r

def intersect_skew_lines_high_dim(p, r, mask=None):
    # Implements https://en.wikipedia.org/wiki/Skew_lines (section "In more than two dimensions")
    dim = p.shape[-1]
    # make sure the heading vectors are l2-normed
    if mask is None:
        mask = torch.ones_like(p[..., 0])
    r = torch.nn.functional.normalize(r, dim=-1)

    eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None]
    I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None]
    sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)
    p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]

    if torch.any(torch.isnan(p_intersect)):
        print(p_intersect)
        assert False
    return p_intersect, r
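
# For reference, the least-squares solve above corresponds to the closed form for the
# point nearest to all (masked) lines: with unit directions r_i and points p_i on each
# line, (I - r_i r_i^T) projects onto the subspace orthogonal to line i, and p* solves
#     [ sum_i (I - r_i r_i^T) ] p* = sum_i (I - r_i r_i^T) p_i,
# i.e. p* minimizes the sum of squared point-to-line distances (cf. the Wikipedia
# article referenced above).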

def _point_line_distance(p1, r1, p2):
    df = p2 - p1
    proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1)
    line_pt_nearest = p2 - proj_vector
    d = (proj_vector).norm(dim=-1)
    return d, line_pt_nearest

def compute_optical_axis_intersection(cameras):
    centers = cameras.get_camera_center()
    principal_points = cameras.principal_point
    one_vec = torch.ones((len(cameras), 1))
    optical_axis = torch.cat((principal_points, one_vec), -1)
    pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True)
    # unproject_points broadcasts over the camera batch, so pp is (N, N, 3);
    # keep the diagonal so each camera unprojects its own principal point.
    pp2 = torch.zeros((pp.shape[0], 3))
    for i in range(0, pp.shape[0]):
        pp2[i] = pp[i][i]

    directions = pp2 - centers
    centers = centers.unsqueeze(0).unsqueeze(0)
    directions = directions.unsqueeze(0).unsqueeze(0)

    p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(
        p=centers, r=directions, mask=None
    )
    p_intersect = p_intersect.squeeze().unsqueeze(0)
    dist = (p_intersect - centers).norm(dim=-1)
    return p_intersect, dist, p_line_intersect, pp2, r

def normalize_cameras(cameras, scale=1.0):
    """
    Normalizes cameras such that the least-squares intersection point of the
    optical axes is placed at the origin and the camera farthest from that point
    ends up at distance 1 from the origin.

    Args:
        cameras (PerspectiveCameras): batch of cameras to normalize.

    Returns:
        Tuple of (new_cameras, p_intersect, p_line_intersect, pp, r), or -1 if
        the computed scale is zero (degenerate case).
    """
    new_cameras = cameras.clone()
    new_transform = new_cameras.get_world_to_view_transform()

    p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection(
        cameras
    )
    t = Translate(p_intersect)

    # Scale by the maximum camera-to-intersection distance
    # scale = dist.squeeze()[0]
    scale = max(dist.squeeze())

    # Degenerate case
    if scale == 0:
        print(cameras.T)
        print(new_transform.get_matrix()[:, 3, :3])
        return -1
    assert scale != 0

    new_transform = t.compose(new_transform)
    new_cameras.R = new_transform.get_matrix()[:, :3, :3]
    new_cameras.T = new_transform.get_matrix()[:, 3, :3] / scale
    return new_cameras, p_intersect, p_line_intersect, pp, r

def centerandalign(cameras, scale=1.0):
    """
    Centers the cameras so that the least-squares intersection point of the optical
    axes is at the origin, applies a rotation computed from a least-squares plane fit
    to the camera centers (intended to align the fitted camera plane with the
    xz-plane), and scales so that the first camera is at distance 1 from the origin.

    Args:
        cameras (PerspectiveCameras): batch of cameras to normalize.
    """
    # Let distance from first camera to origin be unit
    new_cameras = cameras.clone()
    new_transform = new_cameras.get_world_to_view_transform()

    p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection(
        cameras
    )
    t = Translate(p_intersect)

    centers = [cam.get_camera_center() for cam in new_cameras]
    centers = torch.concat(centers, 0).cpu().numpy()
    m = len(cameras)

    # Fit a plane z = a*x + b*y + c to the camera centers (least squares):
    # https://math.stackexchange.com/questions/99299/best-fitting-plane-given-a-set-of-points
    A = np.hstack((centers[:m, :2], np.ones((m, 1))))
    B = centers[:m, 2:]
    if A.shape[0] == 2:
        x = A.T @ np.linalg.inv(A @ A.T) @ B
    else:
        x = np.linalg.inv(A.T @ A) @ A.T @ B
    a, b, c = x.flatten()
    n = np.array([a, b, 1])
    n /= np.linalg.norm(n)

    # Rotation aligning n with [0, 1, 0], built via Rodrigues' formula
    # R = I + [v]_x + [v]_x^2 * (1 - c) / s^2, with v = n x [0, 1, 0], s = ||v||, c = n . [0, 1, 0]:
    # https://math.stackexchange.com/questions/180418/calculate-rotation-matrix-to-align-vector-a-to-vector-b-in-3d
    v = np.cross(n, [0, 1, 0])
    s = np.linalg.norm(v)
    c = np.dot(n, [0, 1, 0])
    V = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
    rot = torch.from_numpy(np.eye(3) + V + V @ V * (1 - c) / s**2).float()

    scale = dist.squeeze()[0]

    # Degenerate case
    if scale == 0:
        print(cameras.T)
        print(new_transform.get_matrix()[:, 3, :3])
        return -1
    assert scale != 0

    rot = Rotate(rot.T)
    new_transform = rot.compose(t).compose(new_transform)
    new_cameras.R = new_transform.get_matrix()[:, :3, :3]
    new_cameras.T = new_transform.get_matrix()[:, 3, :3] / scale
    return new_cameras

def square_bbox(bbox, padding=0.0, astype=None):
    """
    Computes a square bounding box, with optional padding parameters.

    Args:
        bbox: Bounding box in xyxy format (4,).

    Returns:
        square_bbox in xyxy format (4,).
    """
    if astype is None:
        astype = type(bbox[0])
    bbox = np.array(bbox)
    center = ((bbox[:2] + bbox[2:]) / 2).round().astype(int)
    extents = (bbox[2:] - bbox[:2]) / 2
    s = (max(extents) * (1 + padding)).round().astype(int)
    square_bbox = np.array(
        [center[0] - s, center[1] - s, center[0] + s, center[1] + s],
        dtype=astype,
    )
    return square_bbox
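
# A quick worked example of square_bbox (values checked by hand): square_bbox([10, 20, 50, 100])
# has center (30, 60) and half-side 40 (the larger extent), so it returns [-10, 20, 70, 100];
# padding > 0 grows the half-side by a factor of (1 + padding) before rounding.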

class Co3dDataset(Dataset):
    def __init__(
        self,
        category,
        split="train",
        skip=2,
        img_size=1024,
        num_images=4,
        mask_images=False,
        single_id=0,
        bbox=False,
        modifier_token=None,
        addreg=False,
        drop_ratio=0.5,
        drop_txt=0.1,
        categoryname=None,
        aligncameras=False,
        repeat=100,
        addlen=False,
        onlyref=False,
    ):
        """
        Args:
            category (str): Comma-separated list of categories to use (e.g. "car").
            num_images (int): Default number of images in each batch.
            mask_images (bool): If True, masks out the background of the images.

        Cameras are always normalized so that the intersection of the optical axes
        is placed at the origin (see normalize_cameras); if aligncameras is True,
        centerandalign is used instead.
        """
        # category = CATEGORIES
        category = sorted(category.split(','))
        self.category = category
        self.single_id = single_id
        self.addlen = addlen
        self.onlyref = onlyref
        self.categoryname = categoryname
        self.bbox = bbox
        self.modifier_token = modifier_token
        self.addreg = addreg
        self.drop_txt = drop_txt
        self.skip = skip
        if self.addreg:
            with open(f'data/regularization/{category[0]}_sp_generated/caption.txt', "r") as f:
                self.regcaptions = f.read().splitlines()
            self.reglen = len(self.regcaptions)
            self.regimpath = f'data/regularization/{category[0]}_sp_generated'

        self.low_quality_translations = []
        self.rotations = {}
        self.category_map = {}
        co3d_dir = CO3D_DIR
        for c in category:
            subset = 'fewview_dev'
            category_dir = osp.join(co3d_dir, c)
            frame_file = osp.join(category_dir, "frame_annotations.jgz")
            sequence_file = osp.join(category_dir, "sequence_annotations.jgz")
            subset_lists_file = osp.join(category_dir, f"set_lists/set_lists_{subset}.json")
            bbox_file = osp.join(category_dir, f"{c}_bbox.jgz")

            with open(subset_lists_file) as f:
                subset_lists_data = json.load(f)
            with gzip.open(sequence_file, "r") as fin:
                sequence_data = json.loads(fin.read())
            with gzip.open(bbox_file, "r") as fin:
                bbox_data = json.loads(fin.read())
            with gzip.open(frame_file, "r") as fin:
                frame_data = json.loads(fin.read())

            frame_data_processed = {}
            for f_data in frame_data:
                sequence_name = f_data["sequence_name"]
                if sequence_name not in frame_data_processed:
                    frame_data_processed[sequence_name] = {}
                frame_data_processed[sequence_name][f_data["frame_number"]] = f_data

            good_quality_sequences = set()
            for seq_data in sequence_data:
                if seq_data["viewpoint_quality_score"] > 0.5:
                    good_quality_sequences.add(seq_data["sequence_name"])

            for subset in ["train"]:
                for seq_name, frame_number, filepath in subset_lists_data[subset]:
                    if seq_name not in good_quality_sequences:
                        continue
                    if seq_name not in self.rotations:
                        self.rotations[seq_name] = []
                        self.category_map[seq_name] = c
                    mask_path = filepath.replace("images", "masks").replace(".jpg", ".png")
                    frame_data = frame_data_processed[seq_name][frame_number]
                    self.rotations[seq_name].append(
                        {
                            "filepath": filepath,
                            "R": frame_data["viewpoint"]["R"],
                            "T": frame_data["viewpoint"]["T"],
                            "focal_length": frame_data["viewpoint"]["focal_length"],
                            "principal_point": frame_data["viewpoint"]["principal_point"],
                            "mask": mask_path,
                            "txt": "a car",
                            "bbox": bbox_data[mask_path]
                        }
                    )

        # Iterate over a copy of the keys since degenerate sequences are deleted below.
        for seq_name in list(self.rotations.keys()):
            seq_data = self.rotations[seq_name]
            cameras = PerspectiveCameras(
                focal_length=[data["focal_length"] for data in seq_data],
                principal_point=[data["principal_point"] for data in seq_data],
                R=[data["R"] for data in seq_data],
                T=[data["T"] for data in seq_data],
            )
            normalized_cameras, _, _, _, _ = normalize_cameras(cameras)
            if aligncameras:
                normalized_cameras = centerandalign(cameras)

            if normalized_cameras == -1:
                print("Error in normalizing cameras: camera scale was 0")
                del self.rotations[seq_name]
                continue

            for i, data in enumerate(seq_data):
                self.rotations[seq_name][i]["R"] = normalized_cameras.R[i]
                self.rotations[seq_name][i]["T"] = normalized_cameras.T[i]
                self.rotations[seq_name][i]["R_original"] = torch.from_numpy(np.array(seq_data[i]["R"]))
                self.rotations[seq_name][i]["T_original"] = torch.from_numpy(np.array(seq_data[i]["T"]))

                # Make sure translations are not ridiculous
                if self.rotations[seq_name][i]["T"][0] + self.rotations[seq_name][i]["T"][1] + self.rotations[seq_name][i]["T"][2] > 1e5:
                    bad_seq = True
                    self.low_quality_translations.append(seq_name)
                    break

        for seq_name in self.low_quality_translations:
            if seq_name in self.rotations:
                del self.rotations[seq_name]

        self.sequence_list = list(self.rotations.keys())

        self.transform = transforms.Compose(
            [
                transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.ToTensor(),
                transforms.Lambda(lambda x: x * 2.0 - 1.0)
            ]
        )
        self.transformim = transforms.Compose(
            [
                transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.CenterCrop(img_size),
                transforms.ToTensor(),
                transforms.Lambda(lambda x: x * 2.0 - 1.0)
            ]
        )
        self.transformmask = transforms.Compose(
            [
                transforms.Resize(img_size // 8),
                transforms.ToTensor(),
            ]
        )
        self.num_images = num_images
        self.image_size = img_size
        self.normalize_cameras = normalize_cameras
        self.mask_images = mask_images
        self.drop_ratio = drop_ratio
        self.kernel_tensor = torch.ones((1, 1, 7, 7))
        self.repeat = repeat
        print("Training sequences:", self.sequence_list)
        self.valid_ids = np.arange(0, len(self.rotations[self.sequence_list[self.single_id]]), skip).tolist()
        if split == 'test':
            self.valid_ids = list(set(np.arange(0, len(self.rotations[self.sequence_list[self.single_id]])).tolist()).difference(self.valid_ids))

        print(
            f"Low quality translation sequences, not used: {self.low_quality_translations}"
        )
        print(f"Data size: {len(self)}")

    def __len__(self):
        return len(self.valid_ids) * self.repeat + (1 if self.addlen else 0)

    def _padded_bbox(self, bbox, w, h):
        # The input bbox is ignored: the padded crop always squares the full image extent.
        if w < h:
            bbox = np.array([0, 0, w, h])
        else:
            bbox = np.array([0, 0, w, h])
        return square_bbox(bbox.astype(np.float32))

    def _crop_bbox(self, bbox, w, h):
        bbox = square_bbox(bbox.astype(np.float32))

        side_length = bbox[2] - bbox[0]
        center = (bbox[:2] + bbox[2:]) / 2
        extent = side_length / 2

        # Final coordinates need to be integer for cropping.
        ul = (center - extent).round().astype(int)
        lr = ul + np.round(2 * extent).astype(int)
        return np.concatenate((ul, lr))

    def _crop_image(self, image, bbox, white_bg=False):
        if white_bg:
            # Only supports PIL Images
            image_crop = Image.new(
                "RGB", (bbox[2] - bbox[0], bbox[3] - bbox[1]), (255, 255, 255)
            )
            image_crop.paste(image, (-bbox[0], -bbox[1]))
        else:
            image_crop = transforms.functional.crop(
                image,
                top=bbox[1],
                left=bbox[0],
                height=bbox[3] - bbox[1],
                width=bbox[2] - bbox[0],
            )
        return image_crop

    def __getitem__(self, index, specific_id=None, validation=False):
        sequence_name = self.sequence_list[self.single_id]
        metadata = self.rotations[sequence_name]

        if validation:
            drop_text = False
            drop_im = False
        else:
            drop_im = np.random.uniform(0, 1) < self.drop_ratio
            if not drop_im:
                drop_text = np.random.uniform(0, 1) < self.drop_txt
            else:
                drop_text = False
        size = self.image_size

        # sample reference ids
        listofindices = self.valid_ids.copy()
        max_diff = len(listofindices) // (self.num_images - 1)
        if (index * self.skip) % len(metadata) in listofindices:
            listofindices.remove((index * self.skip) % len(metadata))
        references = np.random.choice(np.arange(0, len(listofindices) + 1, max_diff), self.num_images - 1, replace=False)
        rem = np.random.randint(0, max_diff)
        references = [listofindices[(x + rem) % len(listofindices)] for x in references]
        ids = [(index * self.skip) % len(metadata)] + references

        # special case to save features corresponding to ref image as part of model buffer
        if self.onlyref:
            ids = references + [(index * self.skip) % len(metadata)]
        if specific_id is not None:  # remove this later
            ids = specific_id

        # get data
        batch = self.get_data(index=self.single_id, ids=ids)

        # text prompt
        if self.modifier_token is not None:
            name = self.category[0] if self.categoryname is None else self.categoryname
            batch['txt'] = [f'photo of a {self.modifier_token} {name}' for _ in range(len(batch['txt']))]

        # replace with regularization image if drop_im
        if drop_im and self.addreg:
            select_id = np.random.randint(0, self.reglen)
            batch["image"] = [self.transformim(Image.open(f'{self.regimpath}/images/{select_id}.png').convert('RGB'))]
            batch['txt'] = [self.regcaptions[select_id]]
            batch["original_size_as_tuple"] = torch.ones_like(batch["original_size_as_tuple"]) * 1024

        # create camera class and adjust intrinsics for crop
        cameras = [PerspectiveCameras(R=batch['R'][i].unsqueeze(0),
                                      T=batch['T'][i].unsqueeze(0),
                                      focal_length=batch['focal_lengths'][i].unsqueeze(0),
                                      principal_point=batch['principal_points'][i].unsqueeze(0),
                                      image_size=self.image_size)
                   for i in range(len(ids))]
        for i, cam in enumerate(cameras):
            adjust_camera_to_bbox_crop_(cam, batch["original_size_as_tuple"][i, :2], batch["crop_coords"][i])
            adjust_camera_to_image_scale_(cam, batch["original_size_as_tuple"][i, 2:], torch.tensor([self.image_size, self.image_size]))

        # create mask and dilated mask for mask based losses
        batch["depth"] = batch["mask"].clone()
        batch["mask"] = torch.clamp(torch.nn.functional.conv2d(batch["mask"], self.kernel_tensor, padding='same'), 0, 1)
        if not self.mask_images:
            batch["mask"] = [None for i in range(len(ids))]

        # special case to save features corresponding to zero image
        if index == self.__len__() - 1 and self.addlen:
            batch["image"][0] *= 0.

        return {"jpg": batch["image"][0],
                "txt": batch["txt"][0] if not drop_text else "",
                "jpg_ref": batch["image"][1:] if not drop_im else torch.stack([2 * torch.rand_like(batch["image"][0]) - 1. for _ in range(len(ids) - 1)], dim=0),
                "txt_ref": batch["txt"][1:] if not drop_im else ["" for _ in range(len(ids) - 1)],
                "pose": cameras,
                "mask": batch["mask"][0] if not drop_im else torch.ones_like(batch["mask"][0]),
                "mask_ref": batch["masks_padding"][1:],
                "depth": batch["depth"][0] if len(batch["depth"]) > 0 else None,
                "filepaths": batch["filepaths"],
                "original_size_as_tuple": batch["original_size_as_tuple"][0][2:],
                "target_size_as_tuple": torch.ones_like(batch["original_size_as_tuple"][0][2:]) * size,
                "crop_coords_top_left": torch.zeros_like(batch["crop_coords"][0][:2]),
                "original_size_as_tuple_ref": batch["original_size_as_tuple"][1:][:, 2:],
                "target_size_as_tuple_ref": torch.ones_like(batch["original_size_as_tuple"][1:][:, 2:]) * size,
                "crop_coords_top_left_ref": torch.zeros_like(batch["crop_coords"][1:][:, :2]),
                "drop_im": torch.Tensor([1 - drop_im * 1.])
                }
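
    # Each item pairs a single target view ("jpg"/"txt") with num_images - 1 reference
    # views ("jpg_ref"/"txt_ref"), plus one PyTorch3D camera per view in "pose"
    # (target first, unless onlyref puts the reference views first). "drop_im" records
    # whether the reference views were kept (1.0) or randomly dropped (0.0).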

    def get_data(self, index=None, sequence_name=None, ids=(0, 1)):
        if sequence_name is None:
            sequence_name = self.sequence_list[index]
        metadata = self.rotations[sequence_name]
        category = self.category_map[sequence_name]
        annos = [metadata[i] for i in ids]

        images = []
        rotations = []
        translations = []
        focal_lengths = []
        principal_points = []
        txts = []
        masks = []
        filepaths = []
        images_transformed = []
        masks_transformed = []
        original_size_as_tuple = []
        crop_parameters = []
        masks_padding = []
        depths = []

        for counter, anno in enumerate(annos):
            filepath = anno["filepath"]
            filepaths.append(filepath)
            image = Image.open(osp.join(CO3D_DIR, filepath)).convert("RGB")

            mask_name = osp.basename(filepath.replace(".jpg", ".png"))
            mask_path = osp.join(
                CO3D_DIR, category, sequence_name, "masks", mask_name
            )
            mask = Image.open(mask_path).convert("L")
            if mask.size != image.size:
                mask = mask.resize(image.size)
            mask_padded = Image.fromarray((np.ones_like(mask) > 0))
            mask = Image.fromarray((np.array(mask) > 125))
            masks.append(mask)

            # crop image around object
            w, h = image.width, image.height
            bbox = np.array(anno["bbox"])
            if len(bbox) == 0:
                bbox = np.array([0, 0, w, h])
            if self.bbox and counter > 0:
                bbox = self._crop_bbox(bbox, w, h)
            else:
                bbox = self._padded_bbox(None, w, h)

            image = self._crop_image(image, bbox)
            mask = self._crop_image(mask, bbox)
            mask_padded = self._crop_image(mask_padded, bbox)
            masks_padding.append(self.transformmask(mask_padded))
            images_transformed.append(self.transform(image))
            masks_transformed.append(self.transformmask(mask))
            crop_parameters.append(torch.tensor([bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]]).int())
            original_size_as_tuple.append(torch.tensor([w, h, bbox[2] - bbox[0], bbox[3] - bbox[1]]))

            images.append(image)
            rotations.append(anno["R"])
            translations.append(anno["T"])
            focal_lengths.append(torch.tensor(anno["focal_length"]))
            principal_points.append(torch.tensor(anno["principal_point"]))
            txts.append(anno["txt"])

        images = images_transformed
        batch = {
            "model_id": sequence_name,
            "category": category,
            "original_size_as_tuple": torch.stack(original_size_as_tuple),
            "crop_coords": torch.stack(crop_parameters),
            "n": len(metadata),
            "ind": torch.tensor(ids),
            "txt": txts,
            "filepaths": filepaths,
            "masks_padding": torch.stack(masks_padding) if len(masks_padding) > 0 else [],
            "depth": torch.stack(depths) if len(depths) > 0 else [],
        }
        batch["R"] = torch.stack(rotations)
        batch["T"] = torch.stack(translations)
        batch["focal_lengths"] = torch.stack(focal_lengths)
        batch["principal_points"] = torch.stack(principal_points)

        # Add images
        if self.transform is None:
            batch["image"] = images
        else:
            batch["image"] = torch.stack(images)
        batch["mask"] = torch.stack(masks_transformed)
        return batch

    @staticmethod
    def collate_fn(batch):
        """Collates the data across a batch. Pass this function to PyTorch's DataLoader as `collate_fn`.

        Args:
            batch (list): List of objects returned by this class' __getitem__ function,
                as gathered by the DataLoader.

        Returns:
            dict: The collated dictionary representing the data in the batch.
        """
        result = {
            "jpg": [],
            "txt": [],
            "jpg_ref": [],
            "txt_ref": [],
            "pose": [],
            "original_size_as_tuple": [],
            "original_size_as_tuple_ref": [],
            "crop_coords_top_left": [],
            "crop_coords_top_left_ref": [],
            "target_size_as_tuple_ref": [],
            "target_size_as_tuple": [],
            "drop_im": [],
            "mask_ref": [],
        }
        if batch[0]["mask"] is not None:
            result["mask"] = []
        if batch[0]["depth"] is not None:
            result["depth"] = []

        for batch_obj in batch:
            for key in result.keys():
                result[key].append(batch_obj[key])
        for key in result.keys():
            if not (key == 'pose' or 'txt' in key or 'size_as_tuple_ref' in key or 'coords_top_left_ref' in key):
                result[key] = torch.stack(result[key], dim=0)
            elif 'txt_ref' in key:
                result[key] = [item for sublist in result[key] for item in sublist]
            elif 'size_as_tuple_ref' in key or 'coords_top_left_ref' in key:
                result[key] = torch.cat(result[key], dim=0)
            elif 'pose' in key:
                result[key] = [join_cameras_as_batch(cameras) for cameras in result[key]]
        return result
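
# Roughly, the shapes after collation (for batch size B and num_images views per item):
# "jpg" stacks to (B, C, H, W) and "jpg_ref" to (B, num_images - 1, C, H, W); "txt" stays
# a list of B strings while "txt_ref" is flattened to B * (num_images - 1) strings; "pose"
# becomes a list of B camera batches, each joining that item's per-view cameras.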

class CustomDataDictLoader(pl.LightningDataModule):
    def __init__(
        self,
        category,
        batch_size,
        mask_images=False,
        skip=1,
        img_size=1024,
        num_images=4,
        num_workers=0,
        shuffle=True,
        single_id=0,
        modifier_token=None,
        bbox=False,
        addreg=False,
        drop_ratio=0.5,
        jitter=False,
        drop_txt=0.1,
        categoryname=None,
    ):
        super().__init__()

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.train_dataset = Co3dDataset(category,
                                         img_size=img_size,
                                         mask_images=mask_images,
                                         skip=skip,
                                         num_images=num_images,
                                         single_id=single_id,
                                         modifier_token=modifier_token,
                                         bbox=bbox,
                                         addreg=addreg,
                                         drop_ratio=drop_ratio,
                                         drop_txt=drop_txt,
                                         categoryname=categoryname,
                                         )
        self.val_dataset = Co3dDataset(category,
                                       img_size=img_size,
                                       mask_images=mask_images,
                                       skip=skip,
                                       num_images=2,
                                       single_id=single_id,
                                       modifier_token=modifier_token,
                                       bbox=bbox,
                                       addreg=addreg,
                                       drop_ratio=0.,
                                       drop_txt=0.,
                                       categoryname=categoryname,
                                       repeat=1,
                                       addlen=True,
                                       onlyref=True,
                                       )
        self.test_dataset = Co3dDataset(category,
                                        img_size=img_size,
                                        mask_images=mask_images,
                                        split="test",
                                        skip=skip,
                                        num_images=2,
                                        single_id=single_id,
                                        modifier_token=modifier_token,
                                        bbox=False,
                                        addreg=addreg,
                                        drop_ratio=0.,
                                        drop_txt=0.,
                                        categoryname=categoryname,
                                        repeat=1,
                                        )
        self.collate_fn = Co3dDataset.collate_fn

    def prepare_data(self):
        pass

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=True,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=True,
        )
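

# Minimal smoke test (a sketch, not part of the training pipeline): it assumes a
# CO3D-style category folder (frame_annotations.jgz, sequence_annotations.jgz,
# set_lists/, masks/ and a <category>_bbox.jgz file) exists under CO3D_DIR; the
# "teddybear" name below is only an illustrative placeholder.
if __name__ == "__main__":
    datamodule = CustomDataDictLoader(
        category="teddybear",
        batch_size=1,
        img_size=512,
        num_images=4,
        num_workers=0,
    )
    loader = datamodule.train_dataloader()
    batch = next(iter(loader))
    # A few collated entries: target image, reference images, and per-sample cameras.
    print(batch["jpg"].shape, batch["jpg_ref"].shape, len(batch["pose"]))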