Spaces:
Runtime error
Runtime error
# coding=utf-8 | |
# Copyright 2024 The Google Research Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Preprocessing for referring-expression datasets.

Adapted from
https://github.com/yz93/LAVT-RIS/blob/main/data/dataset_refer_bert.py
"""
# pylint: disable=all | |
from refer.refer import REFER | |
from torch.utils import data | |
class ReferDataset(data.Dataset):
    """Image-level view over one split of a REFER-indexed dataset.

    Item ``i`` corresponds to the ``i``-th image of the split and yields the
    image file name together with the raw sentences of every ref that
    ``REFER.getRefIds(image_id)`` returns for that image.
    """

    def __init__(
        self,
        root,
        dataset='refcoco',
        splitBy='unc',
        image_transforms=None,
        target_transforms=None,
        split='train',
        eval_mode=False,
    ):
        """Build the REFER index and cache the ids/sentences of one split.

        Args:
            root: Forwarded unchanged to ``REFER`` as its first argument
                (presumably the dataset root directory -- confirm against
                the REFER API).
            dataset: Dataset name forwarded to ``REFER``.
            splitBy: Split provider forwarded to ``REFER``.
            image_transforms: Stored on the instance; not applied by any
                method of this class.
            target_transforms: Stored on the instance; not applied by any
                method of this class.
            split: Split name used to select ref ids.
            eval_mode: Stored on the instance. It does not alter which
                sentences are collected: every sentence of every ref is
                kept regardless of its value.
        """
        super().__init__()
        self.classes = []
        self.image_transforms = image_transforms
        self.target_transforms = target_transforms
        self.split = split
        self.eval_mode = eval_mode

        self.refer = REFER(root, dataset=dataset, splitBy=splitBy)

        ref_ids = self.refer.getRefIds(split=self.split)
        img_ids = self.refer.getImgIds(ref_ids)
        all_imgs = self.refer.Imgs
        self.imgs = [all_imgs[img_id] for img_id in img_ids]
        self.ref_ids = ref_ids

        # Size report kept from the original: number of refs, then images.
        print(len(ref_ids))
        print(len(self.imgs))

        # One list of raw sentences per ref, in the same order as ref_ids.
        self.sentence_raw = [
            self._raw_sentences(self.refer.Refs[ref_id]) for ref_id in ref_ids
        ]

    @staticmethod
    def _raw_sentences(ref):
        """Return the ``'raw'`` text of each sentence in one ref record."""
        # The pairing with 'sent_ids' is kept from the original so that the
        # iteration length (shorter of the two lists) is unchanged.
        return [
            sentence['raw']
            for sentence, _ in zip(ref['sentences'], ref['sent_ids'])
        ]

    def get_classes(self):
        """Return the ``classes`` list (initialised empty in ``__init__``)."""
        return self.classes

    def __len__(self):
        """Return the number of images in the split."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Return the file name and referring sentences for one image.

        Args:
            index: Position into ``self.imgs``.

        Returns:
            A pair ``([file_name], sentences)`` where ``sentences`` maps each
            ``ref_id`` of the image to its list of raw sentence strings.
        """
        image_id = self.imgs[index]['id']
        # NOTE(review): no split filter is passed here, so the result is
        # whatever REFER associates with the image -- confirm that refs from
        # other splits cannot appear for the chosen splitBy.
        ref_ids = self.refer.getRefIds(image_id)
        image_info = self.refer.Imgs[image_id]

        batch_sentences = {}
        for ref_id in ref_ids:
            # loadRefs returns a list; the record of interest is element 0.
            ref = self.refer.loadRefs(ref_id)[0]
            batch_sentences[ref['ref_id']] = self._raw_sentences(ref)
        return [image_info['file_name']], batch_sentences

    def get_ref(self):
        """Print the file name of every ref in the split (debug helper).

        For each ref id in ``self.ref_ids`` the ref's ``'file_name'`` is
        printed, preceded by a line containing ``1`` when the name is empty.
        Afterwards the total number of names and the number of distinct
        names are printed. Returns ``None``.
        """
        file_names = []
        for ref_id in self.ref_ids:
            file_name = self.refer.loadRefs(ref_id)[0]['file_name']
            if file_name == '':
                # Marker for refs that carry no file name.
                print(1)
            file_names.append(file_name)
            print(file_name)
        print(len(file_names))
        print(len(set(file_names)))