# Spaces:
# Runtime error
# Runtime error
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import csv
from enum import Enum
import logging
import os
from typing import Callable, List, Optional, Tuple, Union

import numpy as np

from .extended import ExtendedVisionDataset
# Module-level logger; all messages from this file go to the "dinov2" logger.
logger = logging.getLogger("dinov2")

# Type of a single classification target (the integer class index).
_Target = int
class _Split(Enum):
    """Dataset splits of ImageNet (ILSVRC2012).

    Each member's value is the name of the split's directory; for the
    validation and test splits it is also embedded in the image file names.
    """

    TRAIN = "train"
    VAL = "val"
    TEST = "test"  # NOTE: torchvision does not support the test split

    @property
    def length(self) -> int:
        """Expected number of images in this split.

        Exposed as a property because callers read it as an attribute
        (``split.length``) and use it directly as an integer.
        """
        split_lengths = {
            _Split.TRAIN: 1_281_167,
            _Split.VAL: 50_000,
            _Split.TEST: 100_000,
        }
        return split_lengths[self]

    def get_dirname(self, class_id: Optional[str] = None) -> str:
        """Return the split directory, or the class sub-directory inside it.

        Args:
            class_id: Optional class identifier; when given it is appended
                to the split directory as a sub-directory.

        Returns:
            A path relative to the dataset root.
        """
        return self.value if class_id is None else os.path.join(self.value, class_id)

    def get_image_relpath(self, actual_index: int, class_id: Optional[str] = None) -> str:
        """Build the path of an image relative to the dataset root.

        Args:
            actual_index: Numeric index that appears in the image file name.
            class_id: Class identifier. It selects the sub-directory and, for
                the TRAIN split, is also the prefix of the file name.

        Returns:
            The relative path of the ``.JPEG`` file.
        """
        dirname = self.get_dirname(class_id)
        if self == _Split.TRAIN:
            # Training images are named "<class_id>_<index>.JPEG".
            basename = f"{class_id}_{actual_index}"
        else:  # self in (_Split.VAL, _Split.TEST):
            # Validation/test images use a zero-padded, 8-digit index.
            basename = f"ILSVRC2012_{self.value}_{actual_index:08d}"
        return os.path.join(dirname, basename + ".JPEG")

    def parse_image_relpath(self, image_relpath: str) -> Tuple[str, int]:
        """Recover ``(class_id, actual_index)`` from an image relative path.

        The class ID is the name of the directory that directly contains the
        image, and the index is the integer after the last underscore of the
        file name (extension removed).

        Args:
            image_relpath: Image path relative to the dataset root.

        Returns:
            A ``(class_id, actual_index)`` tuple.

        Raises:
            AssertionError: If called on the TEST split.
        """
        assert self != _Split.TEST
        dirname, filename = os.path.split(image_relpath)
        class_id = os.path.split(dirname)[-1]
        basename, _ = os.path.splitext(filename)
        actual_index = int(basename.split("_")[-1])
        return class_id, actual_index
class ImageNet(ExtendedVisionDataset):
    """ImageNet dataset driven by pre-computed NumPy metadata ("extra") files.

    Image files are read from ``root``. Per-sample entries, class IDs and
    class names are stored as ``.npy`` files under the ``extra`` directory;
    they are written by :meth:`dump_extra` and loaded lazily (memory-mapped)
    on first access.
    """

    Target = Union[_Target]
    Split = Union[_Split]

    def __init__(
        self,
        *,
        split: "ImageNet.Split",
        root: str,
        extra: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        """Create the dataset without touching any metadata file.

        Args:
            split: Dataset split to use.
            root: Directory that contains the image files.
            extra: Directory that contains (or will receive) the ``.npy``
                metadata files.
            transforms: Forwarded to the base class.
            transform: Forwarded to the base class.
            target_transform: Forwarded to the base class.
        """
        super().__init__(root, transforms, transform, target_transform)
        self._extra_root = extra
        self._split = split

        # Lazily populated caches of the metadata arrays.
        self._entries = None
        self._class_ids = None
        self._class_names = None

    @property
    def split(self) -> "ImageNet.Split":
        """The split this dataset instance was created for."""
        return self._split

    def _get_extra_full_path(self, extra_path: str) -> str:
        """Return ``extra_path`` joined onto the extra directory."""
        return os.path.join(self._extra_root, extra_path)

    def _load_extra(self, extra_path: str) -> np.ndarray:
        """Memory-map (read-only) the ``.npy`` file named ``extra_path``."""
        extra_full_path = self._get_extra_full_path(extra_path)
        return np.load(extra_full_path, mmap_mode="r")

    def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
        """Save ``extra_array`` as ``extra_path``, creating the extra directory if needed."""
        extra_full_path = self._get_extra_full_path(extra_path)
        os.makedirs(self._extra_root, exist_ok=True)
        np.save(extra_full_path, extra_array)

    @property
    def _entries_path(self) -> str:
        """File name of the per-sample entries array for the current split."""
        return f"entries-{self._split.value.upper()}.npy"

    @property
    def _class_ids_path(self) -> str:
        """File name of the class IDs array for the current split."""
        return f"class-ids-{self._split.value.upper()}.npy"

    @property
    def _class_names_path(self) -> str:
        """File name of the class names array for the current split."""
        return f"class-names-{self._split.value.upper()}.npy"

    def _get_entries(self) -> np.ndarray:
        """Return the entries array, loading it on first use."""
        if self._entries is None:
            self._entries = self._load_extra(self._entries_path)
        assert self._entries is not None
        return self._entries

    def _get_class_ids(self) -> np.ndarray:
        """Return the class IDs array, loading it on first use.

        Raises:
            AssertionError: If the dataset uses the TEST split.
        """
        if self._split == _Split.TEST:
            # Raised explicitly so the guard is not stripped under ``python -O``.
            raise AssertionError("Class IDs are not available in TEST split")
        if self._class_ids is None:
            self._class_ids = self._load_extra(self._class_ids_path)
        assert self._class_ids is not None
        return self._class_ids

    def _get_class_names(self) -> np.ndarray:
        """Return the class names array, loading it on first use.

        Raises:
            AssertionError: If the dataset uses the TEST split.
        """
        if self._split == _Split.TEST:
            # Raised explicitly so the guard is not stripped under ``python -O``.
            raise AssertionError("Class names are not available in TEST split")
        if self._class_names is None:
            self._class_names = self._load_extra(self._class_names_path)
        assert self._class_names is not None
        return self._class_names

    def find_class_id(self, class_index: int) -> str:
        """Return the class ID stored at position ``class_index``."""
        class_ids = self._get_class_ids()
        return str(class_ids[class_index])

    def find_class_name(self, class_index: int) -> str:
        """Return the class name stored at position ``class_index``."""
        class_names = self._get_class_names()
        return str(class_names[class_index])

    def get_image_data(self, index: int) -> bytes:
        """Read and return the raw bytes of the image file of sample ``index``."""
        entries = self._get_entries()
        actual_index = entries[index]["actual_index"]

        class_id = self.get_class_id(index)

        image_relpath = self.split.get_image_relpath(actual_index, class_id)
        image_full_path = os.path.join(self.root, image_relpath)
        with open(image_full_path, mode="rb") as f:
            image_data = f.read()
        return image_data

    def get_target(self, index: int) -> Optional[Target]:
        """Return the class index of sample ``index``, or ``None`` for the TEST split."""
        entries = self._get_entries()
        class_index = entries[index]["class_index"]
        return None if self.split == _Split.TEST else int(class_index)

    def get_targets(self) -> Optional[np.ndarray]:
        """Return the class indices of all samples, or ``None`` for the TEST split."""
        entries = self._get_entries()
        return None if self.split == _Split.TEST else entries["class_index"]

    def get_class_id(self, index: int) -> Optional[str]:
        """Return the class ID of sample ``index``, or ``None`` for the TEST split."""
        entries = self._get_entries()
        class_id = entries[index]["class_id"]
        return None if self.split == _Split.TEST else str(class_id)

    def get_class_name(self, index: int) -> Optional[str]:
        """Return the class name of sample ``index``, or ``None`` for the TEST split."""
        entries = self._get_entries()
        class_name = entries[index]["class_name"]
        return None if self.split == _Split.TEST else str(class_name)

    def __len__(self) -> int:
        """Return the number of entries, checking it against the split's expected length."""
        entries = self._get_entries()
        assert len(entries) == self.split.length
        return len(entries)

    def _load_labels(self, labels_path: str) -> List[Tuple[str, str]]:
        """Read ``(class_id, class_name)`` rows from a CSV file under ``root``.

        Args:
            labels_path: Path of the labels file relative to ``root``.

        Returns:
            The rows of the file, in file order.

        Raises:
            RuntimeError: If the file cannot be opened or read.
        """
        labels_full_path = os.path.join(self.root, labels_path)
        labels = []

        try:
            with open(labels_full_path, "r") as f:
                reader = csv.reader(f)
                for row in reader:
                    class_id, class_name = row
                    labels.append((class_id, class_name))
        except OSError as e:
            raise RuntimeError(f'can not read labels file "{labels_full_path}"') from e

        return labels

    def _dump_entries(self) -> None:
        """Build the per-sample entries array for the current split and save it.

        Each entry is a record ``(actual_index, class_index, class_id,
        class_name)``. For the TEST split no labels exist, so entries are
        generated from the expected split length with placeholder class
        fields. For the other splits the samples are enumerated with
        torchvision's ``ImageFolder`` and labelled from ``labels.txt``.
        """
        split = self.split
        if split == ImageNet.Split.TEST:
            dataset = None
            sample_count = split.length
            max_class_id_length, max_class_name_length = 0, 0
        else:
            labels_path = "labels.txt"
            logger.info(f'loading labels from "{labels_path}"')
            labels = self._load_labels(labels_path)

            # NOTE: Using torchvision ImageFolder for consistency
            from torchvision.datasets import ImageFolder

            dataset_root = os.path.join(self.root, split.get_dirname())
            dataset = ImageFolder(dataset_root)
            sample_count = len(dataset)
            # First pass: find the string widths needed by the record dtype.
            max_class_id_length, max_class_name_length = -1, -1
            for sample in dataset.samples:
                _, class_index = sample
                class_id, class_name = labels[class_index]
                max_class_id_length = max(len(class_id), max_class_id_length)
                max_class_name_length = max(len(class_name), max_class_name_length)

        dtype = np.dtype(
            [
                ("actual_index", "<u4"),
                ("class_index", "<u4"),
                ("class_id", f"U{max_class_id_length}"),
                ("class_name", f"U{max_class_name_length}"),
            ]
        )
        entries_array = np.empty(sample_count, dtype=dtype)

        if split == ImageNet.Split.TEST:
            # Placeholder class index: the largest uint32 value. This is the
            # value ``np.uint32(-1)`` used to wrap to; NumPy 2 rejects that
            # out-of-range conversion, so the value is spelled out instead.
            missing_class_index = np.uint32(np.iinfo(np.uint32).max)

            old_percent = -1
            for index in range(sample_count):
                percent = 100 * (index + 1) // sample_count
                if percent > old_percent:
                    logger.info(f"creating entries: {percent}%")
                    old_percent = percent

                # File-name indices are 1-based.
                actual_index = index + 1
                class_index = missing_class_index
                class_id, class_name = "", ""
                entries_array[index] = (actual_index, class_index, class_id, class_name)
        else:
            class_names = {class_id: class_name for class_id, class_name in labels}

            assert dataset is not None
            old_percent = -1
            for index in range(sample_count):
                percent = 100 * (index + 1) // sample_count
                if percent > old_percent:
                    logger.info(f"creating entries: {percent}%")
                    old_percent = percent

                image_full_path, class_index = dataset.samples[index]
                image_relpath = os.path.relpath(image_full_path, self.root)
                class_id, actual_index = split.parse_image_relpath(image_relpath)
                class_name = class_names[class_id]
                entries_array[index] = (actual_index, class_index, class_id, class_name)

        logger.info(f'saving entries to "{self._entries_path}"')
        self._save_extra(entries_array, self._entries_path)

    def _dump_class_ids_and_names(self) -> None:
        """Derive the class ID and class name arrays from the saved entries.

        Both arrays are indexed by class index. Nothing is written for the
        TEST split.
        """
        split = self.split
        if split == ImageNet.Split.TEST:
            return

        entries_array = self._load_extra(self._entries_path)

        # First pass: find the number of classes and the string widths.
        max_class_id_length, max_class_name_length, max_class_index = -1, -1, -1
        for entry in entries_array:
            class_index, class_id, class_name = (
                entry["class_index"],
                entry["class_id"],
                entry["class_name"],
            )
            max_class_index = max(int(class_index), max_class_index)
            max_class_id_length = max(len(str(class_id)), max_class_id_length)
            max_class_name_length = max(len(str(class_name)), max_class_name_length)

        class_count = max_class_index + 1
        class_ids_array = np.empty(class_count, dtype=f"U{max_class_id_length}")
        class_names_array = np.empty(class_count, dtype=f"U{max_class_name_length}")

        # Second pass: fill the per-class slots from the entries.
        for entry in entries_array:
            class_index, class_id, class_name = (
                entry["class_index"],
                entry["class_id"],
                entry["class_name"],
            )
            class_ids_array[class_index] = class_id
            class_names_array[class_index] = class_name

        logger.info(f'saving class IDs to "{self._class_ids_path}"')
        self._save_extra(class_ids_array, self._class_ids_path)

        logger.info(f'saving class names to "{self._class_names_path}"')
        self._save_extra(class_names_array, self._class_names_path)

    def dump_extra(self) -> None:
        """Generate and save all metadata files for the current split."""
        self._dump_entries()
        self._dump_class_ids_and_names()