Spaces:
Runtime error
Runtime error
File size: 2,259 Bytes
48fa639 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
from os.path import expanduser
import torch
import json
import torchvision
from general_utils import get_from_repository
from general_utils import log
from torchvision import transforms
# Five pairs of Pascal VOC class identifiers (WordNet-style "name.n.NN" strings).
# NOTE(review): presumably the class pairs held out as unseen for zero-shot
# splits (the "_ZS" suffix); this constant is not referenced in the code shown.
PASCAL_VOC_CLASSES_ZS = [
    ['cattle.n.01', 'motorcycle.n.01'],
    ['aeroplane.n.01', 'sofa.n.01'],
    ['cat.n.01', 'television.n.03'],
    ['train.n.01', 'bottle.n.01'],
    ['chair.n.01', 'pot_plant.n.01'],
]
class PascalZeroShot(object):
    """Pascal VOC segmentation dataset for zero-shot experiments, backed by JoEm.

    Wraps ``third_party.JoEm``'s ``VOCSegmentation`` and resizes every image
    and label mask to a square ``image_size`` x ``image_size``.

    Args:
        split: ``'train'`` or ``'val'``.
        n_unseen: Number of unseen classes; forwarded to JoEm's
            ``get_unseen_idx`` / ``get_seen_idx``.
        image_size: Side length of the square output image and label.

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'val'``.
    """

    def __init__(self, split: str, n_unseen: int, image_size: int = 224) -> None:
        super().__init__()
        # Fail fast: an unknown split used to leave ``self.voc`` unset, which
        # only surfaced later as an AttributeError in __len__/__getitem__.
        if split not in ('train', 'val'):
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        import sys
        # Path is relative to the current working directory; presumably lets
        # JoEm's own top-level imports resolve. Guarded so that creating
        # several datasets does not keep appending duplicates to sys.path.
        joem_root = 'third_party/JoEm'
        if joem_root not in sys.path:
            sys.path.append(joem_root)
        from third_party.JoEm.data_loader.dataset import VOCSegmentation
        from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC

        self.pascal_classes = VOC
        self.image_size = image_size
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
        ])
        # Nearest-neighbour so resized masks keep exact integer class ids
        # instead of interpolated (blended) values. Built once, reused per item.
        self.label_transform = transforms.Resize(
            (image_size, image_size),
            interpolation=torchvision.transforms.InterpolationMode.NEAREST,
        )

        if split == 'train':
            # JoEm-side augmentation (base/crop size 312); remv_unseen_img=True
            # presumably drops images containing unseen classes — per flag name.
            self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
                                       split=split, transform=True,
                                       transform_args=dict(base_size=312, crop_size=312),
                                       ignore_bg=False, ignore_unseen=False,
                                       remv_unseen_img=True)
        else:  # split == 'val' (validated above)
            self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
                                       split=split, transform=False,
                                       ignore_bg=False, ignore_unseen=False)

        self.unseen_idx = get_unseen_idx(n_unseen)

    def __len__(self):
        """Return the number of samples in the underlying JoEm dataset."""
        return len(self.voc)

    def __getitem__(self, i):
        """Return ``((image,), (label,))`` for sample ``i``.

        Both image and label are resized to ``image_size`` x ``image_size``;
        the label is cast to ``long`` and resized with nearest-neighbour.
        """
        sample = self.voc[i]
        image = self.transform(sample['image'])
        # Add a leading dim for Resize and strip it again afterwards.
        # NOTE(review): assumes sample['label'] is an (H, W) tensor — confirm.
        label = self.label_transform(sample['label'].long().unsqueeze(0))[0]
        return (image,), (label,)
|