Spaces:
Runtime error
Runtime error
import os | |
from glob import glob | |
from collections import defaultdict | |
import numpy as np | |
from PIL import Image | |
class DAVIS(object): | |
SUBSET_OPTIONS = ['train', 'val', 'test-dev', 'test-challenge'] | |
TASKS = ['semi-supervised', 'unsupervised'] | |
DATASET_WEB = 'https://davischallenge.org/davis2017/code.html' | |
VOID_LABEL = 255 | |
def __init__(self, root, task='unsupervised', subset='val', sequences='all', resolution='480p', codalab=False): | |
""" | |
Class to read the DAVIS dataset | |
:param root: Path to the DAVIS folder that contains JPEGImages, Annotations, etc. folders. | |
:param task: Task to load the annotations, choose between semi-supervised or unsupervised. | |
:param subset: Set to load the annotations | |
:param sequences: Sequences to consider, 'all' to use all the sequences in a set. | |
:param resolution: Specify the resolution to use the dataset, choose between '480' and 'Full-Resolution' | |
""" | |
if subset not in self.SUBSET_OPTIONS: | |
raise ValueError(f'Subset should be in {self.SUBSET_OPTIONS}') | |
if task not in self.TASKS: | |
raise ValueError(f'The only tasks that are supported are {self.TASKS}') | |
self.task = task | |
self.subset = subset | |
self.root = root | |
self.img_path = os.path.join(self.root, 'JPEGImages', resolution) | |
annotations_folder = 'Annotations' if task == 'semi-supervised' else 'Annotations_unsupervised' | |
self.mask_path = os.path.join(self.root, annotations_folder, resolution) | |
year = '2019' if task == 'unsupervised' and (subset == 'test-dev' or subset == 'test-challenge') else '2017' | |
self.imagesets_path = os.path.join(self.root, 'ImageSets', year) | |
self._check_directories() | |
if sequences == 'all': | |
with open(os.path.join(self.imagesets_path, f'{self.subset}.txt'), 'r') as f: | |
tmp = f.readlines() | |
sequences_names = [x.strip() for x in tmp] | |
else: | |
sequences_names = sequences if isinstance(sequences, list) else [sequences] | |
self.sequences = defaultdict(dict) | |
for seq in sequences_names: | |
images = np.sort(glob(os.path.join(self.img_path, seq, '*.jpg'))).tolist() | |
if len(images) == 0 and not codalab: | |
raise FileNotFoundError(f'Images for sequence {seq} not found.') | |
self.sequences[seq]['images'] = images | |
masks = np.sort(glob(os.path.join(self.mask_path, seq, '*.png'))).tolist() | |
masks.extend([-1] * (len(images) - len(masks))) | |
self.sequences[seq]['masks'] = masks | |
def _check_directories(self): | |
if not os.path.exists(self.root): | |
raise FileNotFoundError(f'DAVIS not found in the specified directory, download it from {self.DATASET_WEB}') | |
if not os.path.exists(os.path.join(self.imagesets_path, f'{self.subset}.txt')): | |
raise FileNotFoundError(f'Subset sequences list for {self.subset} not found, download the missing subset ' | |
f'for the {self.task} task from {self.DATASET_WEB}') | |
if self.subset in ['train', 'val'] and not os.path.exists(self.mask_path): | |
raise FileNotFoundError(f'Annotations folder for the {self.task} task not found, download it from {self.DATASET_WEB}') | |
def get_frames(self, sequence): | |
for img, msk in zip(self.sequences[sequence]['images'], self.sequences[sequence]['masks']): | |
image = np.array(Image.open(img)) | |
mask = None if msk is None else np.array(Image.open(msk)) | |
yield image, mask | |
def _get_all_elements(self, sequence, obj_type): | |
obj = np.array(Image.open(self.sequences[sequence][obj_type][0])) | |
all_objs = np.zeros((len(self.sequences[sequence][obj_type]), *obj.shape)) | |
obj_id = [] | |
for i, obj in enumerate(self.sequences[sequence][obj_type]): | |
all_objs[i, ...] = np.array(Image.open(obj)) | |
obj_id.append(''.join(obj.split('/')[-1].split('.')[:-1])) | |
return all_objs, obj_id | |
def get_all_images(self, sequence): | |
return self._get_all_elements(sequence, 'images') | |
def get_all_masks(self, sequence, separate_objects_masks=False): | |
masks, masks_id = self._get_all_elements(sequence, 'masks') | |
masks_void = np.zeros_like(masks) | |
# Separate void and object masks | |
for i in range(masks.shape[0]): | |
masks_void[i, ...] = masks[i, ...] == 255 | |
masks[i, masks[i, ...] == 255] = 0 | |
if separate_objects_masks: | |
num_objects = int(np.max(masks[0, ...])) | |
tmp = np.ones((num_objects, *masks.shape)) | |
tmp = tmp * np.arange(1, num_objects + 1)[:, None, None, None] | |
masks = (tmp == masks[None, ...]) | |
masks = masks > 0 | |
return masks, masks_void, masks_id | |
def get_sequences(self): | |
for seq in self.sequences: | |
yield seq | |
if __name__ == '__main__': | |
from matplotlib import pyplot as plt | |
only_first_frame = True | |
subsets = ['train', 'val'] | |
for s in subsets: | |
dataset = DAVIS(root='/home/csergi/scratch2/Databases/DAVIS2017_private', subset=s) | |
for seq in dataset.get_sequences(): | |
g = dataset.get_frames(seq) | |
img, mask = next(g) | |
plt.subplot(2, 1, 1) | |
plt.title(seq) | |
plt.imshow(img) | |
plt.subplot(2, 1, 2) | |
plt.imshow(mask) | |
plt.show(block=True) | |