Spaces:
Runtime error
Runtime error
# coding=utf-8 | |
# Copyright 2024 The Google Research Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Pascal Context Dataset.""" | |
from typing import Any, List, Tuple | |
import numpy as np | |
from PIL import Image | |
# pylint: disable=g-importing-member | |
from torchvision.datasets.voc import _VOCBase | |
# Pascal Context 59-class vocabulary. The integer ids stored in the
# *_CLASS_ID lists below are positions in this list.
PASCAL_CONTEXT_CLASSES = [
    'airplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', 'bird',
    'boat', 'book', 'bottle', 'building', 'bus', 'cabinet', 'car', 'cat',
    'ceiling', 'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain', 'dog',
    'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground', 'horse',
    'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person', 'plate',
    'platform', 'plant', 'road', 'rock', 'sheep', 'shelves', 'sidewalk',
    'sign', 'sky', 'snow', 'sofa', 'table', 'track', 'train', 'tree', 'truck',
    'monitor', 'wall', 'water', 'window', 'wood',
]

# Class names grouped as "stuff" (21 entries). The order here fixes the order
# of PASCAL_CONTEXT_STUFF_CLASS_ID.
PASCAL_CONTEXT_STUFF_CLASS = [
    'bedclothes', 'ceiling', 'cloth', 'curtain', 'floor', 'grass', 'ground',
    'light', 'mountain', 'platform', 'road', 'sidewalk', 'sky', 'snow', 'wall',
    'water', 'window', 'wood', 'door', 'fence', 'rock',
]

# Class names grouped as "things" (38 entries). Together with the stuff list,
# this covers every entry of PASCAL_CONTEXT_CLASSES exactly once.
PASCAL_CONTEXT_THING_CLASS = [
    'airplane', 'bag', 'bed', 'bench', 'bicycle', 'bird', 'boat', 'book',
    'bottle', 'building', 'bus', 'cabinet', 'car', 'cat', 'chair', 'computer',
    'cow', 'cup', 'dog', 'flower', 'food', 'horse', 'keyboard', 'motorbike',
    'mouse', 'person', 'plate', 'plant', 'sheep', 'shelves', 'sign', 'sofa',
    'table', 'track', 'train', 'tree', 'truck', 'monitor',
]

# Ids of the grouped classes, computed from the name lists instead of being
# hand-maintained. They always stay in sync with PASCAL_CONTEXT_CLASSES and
# follow the order of the corresponding name list.
PASCAL_CONTEXT_STUFF_CLASS_ID = [
    PASCAL_CONTEXT_CLASSES.index(name) for name in PASCAL_CONTEXT_STUFF_CLASS
]
PASCAL_CONTEXT_THING_CLASS_ID = [
    PASCAL_CONTEXT_CLASSES.index(name) for name in PASCAL_CONTEXT_THING_CLASS
]
class CONTEXTSegmentation(_VOCBase):
  """Pascal Context segmentation dataset stored in a Pascal VOC layout.

  This reuses torchvision's ``_VOCBase`` for split parsing and path building.
  It only overrides the split directory (``_SPLITS_DIR``), the target
  directory (``_TARGET_DIR``) and the target file extension
  (``_TARGET_FILE_EXT``), so the Pascal Context label images are read instead
  of the VOC ones.

  Constructor arguments are handled by ``_VOCBase``. Per torchvision's
  documentation they are:
    root (string): Root directory of the VOC Dataset.
    year (string, optional): The dataset year, supports years ``"2007"`` to
      ``"2012"``.
    image_set (string, optional): Select the image_set to use, ``"train"``,
      ``"trainval"`` or ``"val"``. If ``year=="2007"``, can also be
      ``"test"``.
    download (bool, optional): If true, downloads the dataset from the
      internet and puts it in root directory. If dataset is already
      downloaded, it is not downloaded again.
    transform (callable, optional): A function/transform that takes in a PIL
      image and returns a transformed version. E.g, ``transforms.RandomCrop``
    target_transform (callable, optional): A function/transform that takes in
      the target and transforms it.
    transforms (callable, optional): A function/transform that takes input
      sample and its target as entry and returns a transformed version.
  """

  _SPLITS_DIR = 'SegmentationContext'
  _TARGET_DIR = 'SegmentationClassContext'
  _TARGET_FILE_EXT = '.png'

  # Must be a property: ``__getitem__`` subscripts ``self.masks[index]``.
  # A plain method would make that raise TypeError, because a bound method
  # is not subscriptable.
  @property
  def masks(self) -> List[str]:
    """Segmentation mask paths, indexed in parallel with ``self.images``."""
    return self.targets

  def __getitem__(self, index: int) -> Tuple[Any, Any]:
    """Get a sample of image and segmentation.

    Args:
      index: Position of the sample in the dataset.

    Returns:
      Tuple ``(image, target)``. ``image`` is the RGB-converted image and
      ``target`` is the mask image opened with PIL. If ``self.transforms`` is
      set, it is applied jointly to both and its outputs are returned instead.
    """
    img = Image.open(self.images[index]).convert('RGB')
    target = Image.open(self.masks[index])
    if self.transforms is not None:
      img, target = self.transforms(img, target)
    return img, target
class CONTEXTDataset(CONTEXTSegmentation):
  """Pascal Context Dataset.

  Samples are ``(image, image_path, target, index)`` tuples. Only an
  image-side ``transform`` is supported. The segmentation target is returned
  untransformed, as an ``np.int32`` array.
  """

  def __init__(self, root, year='2012', split='val', transform=None):
    """Builds the dataset without downloading anything.

    Args:
      root: Dataset root directory, forwarded to ``_VOCBase``.
      year: Dataset year string, forwarded to ``_VOCBase``.
      split: Split name, forwarded to ``_VOCBase`` as ``image_set``.
      transform: Optional callable applied to the RGB PIL image only.
    """
    super().__init__(
        root=root,
        image_set=split,
        year=year,
        transform=transform,
        download=False,
    )

  def __getitem__(self, index: int) -> Tuple[Any, str, np.ndarray, int]:
    """Loads one sample.

    Args:
      index: Position of the sample in the dataset.

    Returns:
      Tuple ``(image, image_path, target, index)``:
        - ``image`` is the RGB image, passed through ``self.transform`` when
          one is set.
        - ``image_path`` is the image path as a string.
        - ``target`` is the mask as an ``np.int32`` array.
        - ``index`` is echoed back unchanged.
    """
    image_path = self.images[index]
    # ``convert`` returns a new, fully loaded image, so the file handle can be
    # closed immediately.
    with Image.open(image_path) as raw_image:
      image = raw_image.convert('RGB')
    # Read mask paths straight from ``self.targets`` (the list that ``masks``
    # wraps). This keeps the lookup independent of how ``masks`` is declared.
    # ``np.asarray`` with an explicit dtype materializes the pixels before the
    # file is closed.
    with Image.open(self.targets[index]) as raw_mask:
      target = np.asarray(raw_mask, dtype=np.int32)
    # Check the attribute that is actually invoked. The original code tested
    # ``self.transforms`` but called ``self.transform``.
    if self.transform is not None:
      image = self.transform(image)
    return image, str(image_path), target, index