# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import collections
import json
import os.path as op

import numpy as np
import torch

from .tsv import TSVYamlDataset, find_file_path_in_yaml
from .box_label_loader import BoxLabelLoader
from maskrcnn_benchmark.data.datasets.coco_dt import CocoDetectionTSV


class VGDetectionTSV(CocoDetectionTSV):
    """Visual Genome detection dataset; identical to ``CocoDetectionTSV`` (no overrides)."""


def sort_key_by_val(dic):
    """Return the keys of ``dic`` ordered by ascending value."""
    sorted_dic = sorted(dic.items(), key=lambda kv: kv[1])
    return [kv[0] for kv in sorted_dic]


def bbox_overlaps(anchors, gt_boxes):
    """Compute pairwise IoU between two sets of boxes.

    Box areas and intersections use the inclusive-pixel convention
    (``x2 - x1 + 1``).

    Args:
        anchors: (N, 4) torch tensor of float boxes (x1, y1, x2, y2).
        gt_boxes: (K, 4) torch tensor of float boxes (x1, y1, x2, y2).

    Returns:
        (N, K) torch tensor; entry [n, k] is IoU(anchors[n], gt_boxes[k]).
    """
    N = anchors.size(0)
    K = gt_boxes.size(0)

    gt_boxes_area = ((gt_boxes[:, 2] - gt_boxes[:, 0] + 1) *
                     (gt_boxes[:, 3] - gt_boxes[:, 1] + 1)).view(1, K)
    anchors_area = ((anchors[:, 2] - anchors[:, 0] + 1) *
                    (anchors[:, 3] - anchors[:, 1] + 1)).view(N, 1)

    # Broadcast both sets to (N, K, 4) so every pair is compared at once.
    boxes = anchors.view(N, 1, 4).expand(N, K, 4)
    query_boxes = gt_boxes.view(1, K, 4).expand(N, K, 4)

    iw = (torch.min(boxes[:, :, 2], query_boxes[:, :, 2]) -
          torch.max(boxes[:, :, 0], query_boxes[:, :, 0]) + 1)
    iw[iw < 0] = 0  # disjoint along x -> no intersection width

    ih = (torch.min(boxes[:, :, 3], query_boxes[:, :, 3]) -
          torch.max(boxes[:, :, 1], query_boxes[:, :, 1]) + 1)
    ih[ih < 0] = 0  # disjoint along y -> no intersection height

    ua = anchors_area + gt_boxes_area - (iw * ih)
    overlaps = iw * ih / ua
    return overlaps


# VG data loader for Danfei Xu's Scene graph focused format.
# todo: if ordering of classes, attributes, relations changed
# todo make sure to re-write the obj_classes.txt/rel_classes.txt files

def _box_filter(boxes, must_overlap=False):
    """Return candidate (subject, object) box index pairs, excluding self-pairs.

    Args:
        boxes: (N, 4) torch tensor of boxes. It must be convertible with
            ``.numpy()``, i.e. a CPU tensor.
        must_overlap: if True, keep only pairs whose boxes overlap (IoU > 0).
            If no pair overlaps, fall back to all pairs.

    Returns:
        (P, 2) integer ndarray of (i, j) index pairs with i != j.
    """
    overlaps = bbox_overlaps(boxes, boxes).numpy() > 0
    np.fill_diagonal(overlaps, 0)

    # Builtin `bool`: the `np.bool` alias was removed in NumPy 1.24.
    all_possib = np.ones_like(overlaps, dtype=bool)
    np.fill_diagonal(all_possib, 0)

    if must_overlap:
        possible_boxes = np.column_stack(np.where(overlaps))
        if possible_boxes.size == 0:
            possible_boxes = np.column_stack(np.where(all_possib))
    else:
        possible_boxes = np.column_stack(np.where(all_possib))
    return possible_boxes


def _finalize_vocab(name_to_idx, idx_to_name, background):
    """Register ``background`` at index 0 and return names ordered by index.

    Both dicts are modified in place. ``idx_to_name`` is keyed by the string
    form of the index (JSON object keys are strings). Asserts that the two
    maps agree for every index.
    """
    name_to_idx[background] = 0
    idx_to_name['0'] = background
    names = sort_key_by_val(name_to_idx)
    assert all([names[i] == idx_to_name[str(i)] for i in range(len(names))])
    return names


def _write_lines_if_missing(path, items):
    """Write each item on its own line to ``path`` unless that file already exists."""
    if not op.isfile(path):
        with open(path, 'w') as f:
            f.writelines("%s\n" % item for item in items)


class VGTSVDataset(TSVYamlDataset):
    """
    Generic TSV dataset format for Object Detection, extended with Visual Genome
    attribute and relation vocabularies (and optional relation labels).
    """

    def __init__(self, yaml_file, extra_fields=None, transforms=None,
                 is_load_label=True, filter_duplicate_rels=True,
                 relation_on=False, cv2_output=False, **kwargs):
        """
        Args:
            yaml_file: dataset YAML passed to ``TSVYamlDataset``. Its config must
                provide a ``jsondict`` entry with the label, attribute and
                predicate index maps.
            extra_fields: extra fields forwarded to ``BoxLabelLoader``.
            transforms: optional callable ``(img, target) -> (img, target)``.
            is_load_label: whether to build the box label loader.
            filter_duplicate_rels: on the train split, keep one random predicate
                per (subject, object) pair.
            relation_on: load relation labels and compute or cache the
                predicate frequency prior.
            cv2_output: forwarded to ``TSVYamlDataset``.
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if the split cannot be inferred from the linelist file name.
        """
        if extra_fields is None:
            extra_fields = []
        self.transforms = transforms
        self.is_load_label = is_load_label
        self.relation_on = relation_on
        super(VGTSVDataset, self).__init__(yaml_file, cv2_output=cv2_output)

        ignore_attrs = self.cfg.get("ignore_attrs", None)
        # construct those maps
        jsondict_file = find_file_path_in_yaml(self.cfg.get("jsondict", None), self.root)
        with open(jsondict_file, 'r') as f:
            jsondict = json.load(f)

        # Infer the split from the linelist file name ('val' also matches 'valid').
        linelist_name = op.basename(self.linelist_file)
        if 'train' in linelist_name:
            self.split = "train"
        elif 'test' in linelist_name or 'val' in linelist_name:
            self.split = "test"
        else:
            raise ValueError("Split must be one of [train, test], but get {}!".format(self.linelist_file))

        self.filter_duplicate_rels = filter_duplicate_rels and self.split == 'train'

        self.class_to_ind = jsondict['label_to_idx']
        self.ind_to_class = jsondict['idx_to_label']
        self.classes = _finalize_vocab(self.class_to_ind, self.ind_to_class, '__background__')
        # writing obj classes to disk for Neural Motif model building.
        _write_lines_if_missing(op.splitext(self.label_file)[0] + ".obj_classes.txt", self.classes)

        self.attribute_to_ind = jsondict['attribute_to_idx']
        self.ind_to_attribute = jsondict['idx_to_attribute']
        self.attributes = _finalize_vocab(self.attribute_to_ind, self.ind_to_attribute, '__no_attribute__')

        self.relation_to_ind = jsondict['predicate_to_idx']
        self.ind_to_relation = jsondict['idx_to_predicate']
        self.relations = _finalize_vocab(self.relation_to_ind, self.ind_to_relation, '__no_relation__')
        # writing rel classes to disk for Neural Motif Model building.
        _write_lines_if_missing(op.splitext(self.label_file)[0] + '.rel_classes.txt', self.relations)

        # label map: minus one because we will add one in BoxLabelLoader
        self.labelmap = {key: val - 1 for key, val in self.class_to_ind.items()}
        # NOTE(review): result is currently unused (the load below is commented out);
        # the lookup is kept in case find_file_path_in_yaml validates the config.
        labelmap_file = find_file_path_in_yaml(self.cfg.get("labelmap_dec"), self.root)
        # self.labelmap_dec = load_labelmap_file(labelmap_file)
        if self.is_load_label:
            self.label_loader = BoxLabelLoader(
                labelmap=self.labelmap,
                extra_fields=extra_fields,
                ignore_attrs=ignore_attrs
            )

        # Frequency prior for relations: a float32 (num_classes, num_classes,
        # num_relations) matrix. Channel 0 holds background pair counts + 1
        # (smoothing). It is normalized over the predicate axis and cached next
        # to the label file; it is computed only on the train split.
        if self.relation_on:
            self.freq_prior_file = op.splitext(self.label_file)[0] + ".freq_prior.npy"
            if self.split == 'train' and not op.exists(self.freq_prior_file):
                print("Computing frequency prior matrix...")
                fg_matrix, bg_matrix = self._get_freq_prior()
                prob_matrix = fg_matrix.astype(np.float32)
                prob_matrix[:, :, 0] = bg_matrix
                prob_matrix[:, :, 0] += 1
                prob_matrix /= np.sum(prob_matrix, 2)[:, :, None]
                np.save(self.freq_prior_file, prob_matrix)

    def _get_freq_prior(self, must_overlap=False):
        """Count predicate occurrences per (subject class, object class) pair.

        Args:
            must_overlap: forwarded to ``_box_filter`` when counting background pairs.

        Returns:
            (fg_matrix, bg_matrix): int64 arrays of shape (C, C, R) and (C, C),
            where C is the number of classes and R the number of relations.
            fg counts annotated predicates; bg counts candidate box pairs.
        """
        fg_matrix = np.zeros((
            len(self.classes),
            len(self.classes),
            len(self.relations)
        ), dtype=np.int64)

        bg_matrix = np.zeros((
            len(self.classes),
            len(self.classes),
        ), dtype=np.int64)

        num_examples = len(self)
        for ex_ind in range(num_examples):
            target = self.get_groundtruth(ex_ind)
            gt_classes = target.get_field('labels').numpy()
            gt_relations = target.get_field('relation_labels').numpy()
            gt_boxes = target.bbox

            # For the foreground, we'll just look at everything
            try:
                o1o2 = gt_classes[gt_relations[:, :2]]
                for (o1, o2), gtr in zip(o1o2, gt_relations[:, 2]):
                    fg_matrix[o1, o2, gtr] += 1
                # For the background, get all of the things that overlap.
                o1o2_total = gt_classes[np.array(
                    _box_filter(gt_boxes, must_overlap=must_overlap), dtype=int)]
                for (o1, o2) in o1o2_total:
                    bg_matrix[o1, o2] += 1
            except IndexError:
                # An image with no relations yields a 1-D empty array, so the
                # 2-D indexing above raises. Such images contribute to neither
                # matrix; any other IndexError trips this assert.
                assert len(gt_relations) == 0

            if ex_ind % 20 == 0:
                print("processing {}/{}".format(ex_ind, num_examples))

        return fg_matrix, bg_matrix

    def relation_loader(self, relation_triplets, target):
        """Attach relation annotations to ``target``.

        Args:
            relation_triplets: sequence of M (subject_idx, object_idx, predicate)
                entries, indexing boxes in ``target``.
            target: BoxList from ``label_loader``.

        Returns:
            ``target`` with two added fields. ``"relation_labels"`` is a tensor of
            the (possibly de-duplicated) triplets. ``"pred_labels"`` is an
            (N, N) int64 matrix whose entry [s, o] is the predicate (0 = none).
        """
        if self.filter_duplicate_rels:
            # Filter out dupes! Keep one randomly chosen predicate per pair.
            assert self.split == 'train'
            all_rel_sets = collections.defaultdict(list)
            for (o0, o1, r) in relation_triplets:
                all_rel_sets[(o0, o1)].append(r)
            relation_triplets = [(k[0], k[1], np.random.choice(v)) for k, v in all_rel_sets.items()]

        # get M*M pred_labels
        relations = torch.zeros([len(target), len(target)], dtype=torch.int64)
        for triplet in relation_triplets:
            subj_id, obj_id, pred = triplet[0], triplet[1], triplet[2]
            relations[subj_id, obj_id] = int(pred)

        relation_triplets = torch.tensor(relation_triplets)
        target.add_field("relation_labels", relation_triplets)
        target.add_field("pred_labels", relations)
        return target

    def get_target_from_annotations(self, annotations, img_size, idx):
        """Build the BoxList target for one image.

        Returns:
            None when labels are disabled or ``annotations`` is empty. Otherwise
            the target built from ``annotations['objects']``, plus a
            ``"difficult"`` field on the test split and relation fields when
            ``relation_on`` is set.
        """
        if not (self.is_load_label and annotations):
            return None

        target = self.label_loader(annotations['objects'], img_size)
        # make sure no boxes are removed
        assert (len(annotations['objects']) == len(target))
        if self.split in ["val", "test"]:
            # add the difficult field
            target.add_field("difficult", torch.zeros(len(target), dtype=torch.int32))
        # load relations
        if self.relation_on:
            target = self.relation_loader(annotations["relations"], target)
        return target

    def get_groundtruth(self, idx, call=False):
        """Return the target for ``idx`` (no transforms applied).

        If ``call`` is True, return ``(img, target, annotations)`` instead.
        """
        img = self.get_image(idx)
        if self.cv2_output:
            img_size = img.shape[:2][::-1]  # h, w -> w, h
        else:
            img_size = img.size  # w, h
        annotations = self.get_annotations(idx)
        target = self.get_target_from_annotations(annotations, img_size, idx)
        if call:
            return img, target, annotations
        return target

    def apply_transforms(self, img, target=None):
        """Apply ``self.transforms`` to ``(img, target)`` if it is set."""
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

    def map_class_id_to_class_name(self, class_id):
        """Return the object class name for ``class_id`` (0 = '__background__')."""
        return self.classes[class_id]

    def map_attribute_id_to_attribute_name(self, attribute_id):
        """Return the attribute name for ``attribute_id`` (0 = '__no_attribute__')."""
        return self.attributes[attribute_id]

    def map_relation_id_to_relation_name(self, relation_id):
        """Return the predicate name for ``relation_id`` (0 = '__no_relation__')."""
        return self.relations[relation_id]