define-hf-demo / vidar /datasets /EUROCDataset.py
Jiading Fang
add define
fc16538
# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
import re
from collections import defaultdict
import os
from abc import ABC
from PIL import Image
import numpy as np
from vidar.utils.read import read_image
from vidar.datasets.BaseDataset import BaseDataset
from vidar.datasets.utils.misc import stack_sample
def dummy_calibration(image):
w, h = [float(d) for d in image.size]
return np.array([[1000. , 0. , w / 2. - 0.5],
[0. , 1000. , h / 2. - 0.5],
[0. , 0. , 1. ]])
def get_idx(filename):
return int(re.search(r'\d+', filename).group())
class EUROCDataset(BaseDataset, ABC):
"""
KITTI dataset class
Parameters
----------
split : String
Split file
stride : Tuple
Which context strides to use
spaceholder : String
Space pattern on input images
data_transform : Function
Transformations applied to the sample
"""
def __init__(self, split, strides=(1,), spaceholder='{:19}', **kwargs):
super().__init__(**kwargs)
self.split = split
self.spaceholder = spaceholder
self.backward_context = strides[0]
self.forward_context = strides[1]
self.has_context = self.backward_context + self.forward_context > 0
self.file_tree = defaultdict(list)
self.read_files(self.path)
self.files = []
for k, v in self.file_tree.items():
file_set = set(self.file_tree[k])
files = [fname for fname in sorted(v) if self._has_context(fname, file_set)]
self.files.extend([[k, fname] for fname in files])
def read_files(self, directory, ext=('.png', '.jpg', '.jpeg'), skip_empty=True):
"""Read input images"""
files = defaultdict(list)
for entry in os.scandir(directory):
relpath = os.path.relpath(entry.path, directory)
if entry.is_dir():
d_files = self.read_files(entry.path, ext=ext, skip_empty=skip_empty)
if skip_empty and not len(d_files):
continue
self.file_tree[entry.path] = d_files[entry.path]
elif entry.is_file():
if ext is None or entry.path.lower().endswith(tuple(ext)):
files[directory].append(relpath)
return files
def __len__(self):
return len(self.files)
def _change_idx(self, idx, filename):
"""Prepare name strings according to index"""
_, ext = os.path.splitext(os.path.basename(filename))
return self.spaceholder.format(idx) + ext
def _has_context(self, filename, file_set):
"""Check if image has context"""
context_paths = self._get_context_file_paths(filename, file_set)
return len([f in file_set for f in context_paths]) >= len(self.context)
def _get_context_file_paths(self, filename, file_set):
"""Get file path for contexts"""
fidx = get_idx(filename)
idxs = [-self.backward_context, -self.forward_context, self.backward_context, self.forward_context]
potential_files = [self._change_idx(fidx + i, filename) for i in idxs]
return [fname for fname in potential_files if fname in file_set]
def _read_rgb_context_files(self, session, filename):
"""Read context images"""
file_set = set(self.file_tree[session])
context_paths = self._get_context_file_paths(filename, file_set)
return [self._read_rgb_file(session, filename) for filename in context_paths]
def _read_rgb_file(self, session, filename):
"""Read target images"""
gray_image = read_image(os.path.join(self.path, session, filename))
gray_image_np = np.array(gray_image)
rgb_image_np = np.stack([gray_image_np for _ in range(3)], axis=2)
return Image.fromarray(rgb_image_np)
def _read_npy_depth(self, session, depth_filename):
"""Read depth from numpy file"""
depth_file_path = os.path.join(self.path, session, '../../depth_maps', depth_filename)
return np.load(depth_file_path)
def _read_depth(self, session, depth_filename):
"""Get the depth map from a file."""
return self._read_npy_depth(session, depth_filename)
def _has_depth(self, session, depth_filename):
"""Check if depth map exists"""
depth_file_path = os.path.join(self.path, session, '../../depth_maps', depth_filename)
return os.path.isfile(depth_file_path)
def __getitem__(self, idx):
"""Get dataset sample"""
samples = []
session, filename = self.files[idx]
image = self._read_rgb_file(session, filename)
sample = {
'idx': idx,
'filename': '%s_%s' % (session, os.path.splitext(filename)[0]),
'rgb': {0: image},
'intrinsics': {0: dummy_calibration(image)},
}
if self.has_context:
image_context = self._read_rgb_context_files(session, filename)
sample['rgb'].update({
key: val for key, val in zip(self.context, image_context)
})
depth_filename = filename.split('.')[0] + 'depth.npy'
if self.with_depth:
if self._has_depth(session, depth_filename):
sample['depth'] = {0: self._read_depth(session, depth_filename)}
samples.append(sample)
if self.data_transform:
samples = self.data_transform(samples)
return stack_sample(samples)
if __name__ == "__main__":
data_dir = '/data/vidar/euroc/euroc_cam/cam0'
euroc_dataset = EUROCDataset(path=data_dir,
strides=[49999872, 50000128],
context=[-1,1],
split='{:19}',
labels=['depth'],
cameras=[[0]],
)
print(len(euroc_dataset))
print('\nsample 0:')
print(euroc_dataset[0].keys())
print(euroc_dataset[0]['filename'])
print(euroc_dataset[0]['rgb'])
print(euroc_dataset[0]['intrinsics'])
print('\nsample 1:')
print(euroc_dataset[1].keys())
print(euroc_dataset[1]['filename'])
print(euroc_dataset[1]['rgb'])
print(euroc_dataset[1]['intrinsics'])