# maskrcnn_benchmark/data/datasets/evaluation/flickr/flickr_eval.py
import json
import os
import os.path as osp
import xml.etree.ElementTree as ET
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
from prettytable import PrettyTable

import maskrcnn_benchmark.utils.mdetr_dist as dist
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou

#### The following loading utilities are imported from | |
#### https://github.com/BryanPlummer/flickr30k_entities/blob/68b3d6f12d1d710f96233f6bd2b6de799d6f4e5b/flickr30k_entities_utils.py
# Changelog: | |
# - Added typing information | |
# - Completed docstrings | |
def get_sentence_data(filename) -> List[Dict[str, Any]]:
    """
    Parses a sentence file from the Flickr30K Entities dataset.

    input:
        filename - full file path to the sentence file to parse
    output:
        a list of dictionaries, one per sentence, with the following fields:
            sentence - the original sentence
            phrases - a list of dictionaries, one per annotated phrase, with
                      the following fields:
                phrase - the text of the annotated phrase
                first_word_index - the position of the first word of
                                   the phrase in the sentence
                phrase_id - an identifier for this phrase
                phrase_type - a list of the coarse categories this
                              phrase belongs to
    """
    with open(filename, "r") as f:
        sentences = f.read().split("\n")

    annotations = []
    for sentence in sentences:
        if not sentence:
            continue

        first_word = []
        phrases = []
        phrase_id = []
        phrase_type = []
        words = []
        current_phrase = []
        add_to_phrase = False
        for token in sentence.split():
            if add_to_phrase:
                if token[-1] == "]":
                    # A trailing "]" closes the current annotated phrase
                    add_to_phrase = False
                    token = token[:-1]
                    current_phrase.append(token)
                    phrases.append(" ".join(current_phrase))
                    current_phrase = []
                else:
                    current_phrase.append(token)

                words.append(token)
            else:
                if token[0] == "[":
                    # A "[/EN#<id>/<type1>/<type2>..." token opens an annotated phrase
                    add_to_phrase = True
                    first_word.append(len(words))
                    parts = token.split("/")
                    phrase_id.append(parts[1][3:])  # strip the "EN#" prefix
                    phrase_type.append(parts[2:])
                else:
                    words.append(token)

        sentence_data = {"sentence": " ".join(words), "phrases": []}
        for index, phrase, p_id, p_type in zip(first_word, phrases, phrase_id, phrase_type):
            sentence_data["phrases"].append(
                {"first_word_index": index, "phrase": phrase, "phrase_id": p_id, "phrase_type": p_type}
            )

        annotations.append(sentence_data)

    return annotations

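# Illustrative example of the sentence format this parser expects (made-up ids):
# each annotated phrase is wrapped as "[/EN#<id>/<type> the words]". Given a
# Sentences file containing the single line
#
#   [/EN#1/people A man] rides [/EN#2/vehicles a bike]
#
# get_sentence_data would return (sketch):
#
#   [{"sentence": "A man rides a bike",
#     "phrases": [{"first_word_index": 0, "phrase": "A man", "phrase_id": "1", "phrase_type": ["people"]},
#                 {"first_word_index": 3, "phrase": "a bike", "phrase_id": "2", "phrase_type": ["vehicles"]}]}]
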
def get_annotations(filename) -> Dict[str, Union[int, List[str], Dict[str, List[List[int]]]]]:
    """
    Parses the xml files in the Flickr30K Entities dataset.

    input:
        filename - full file path to the annotations file to parse
    output:
        dictionary with the following fields:
            scene - list of identifiers which were annotated as
                    pertaining to the whole scene
            nobox - list of identifiers which were annotated as
                    not being visible in the image
            boxes - a dictionary where the keys are identifiers
                    and the values are lists of boxes in the
                    [xmin ymin xmax ymax] format
            height - int representing the height of the image
            width - int representing the width of the image
            depth - int representing the depth of the image
    """
    tree = ET.parse(filename)
    root = tree.getroot()
    size_container = root.findall("size")[0]
    anno_info: Dict[str, Union[int, List[str], Dict[str, List[List[int]]]]] = {}
    all_boxes: Dict[str, List[List[int]]] = {}
    all_noboxes: List[str] = []
    all_scenes: List[str] = []
    for size_element in size_container:
        assert size_element.text
        anno_info[size_element.tag] = int(size_element.text)

    for object_container in root.findall("object"):
        for names in object_container.findall("name"):
            box_id = names.text
            assert box_id
            box_container = object_container.findall("bndbox")
            if len(box_container) > 0:
                if box_id not in all_boxes:
                    all_boxes[box_id] = []
                xmin = int(box_container[0].findall("xmin")[0].text)
                ymin = int(box_container[0].findall("ymin")[0].text)
                xmax = int(box_container[0].findall("xmax")[0].text)
                ymax = int(box_container[0].findall("ymax")[0].text)
                all_boxes[box_id].append([xmin, ymin, xmax, ymax])
            else:
                # Entities without a bounding box are either marked as not
                # visible ("nobndbox") or as pertaining to the whole scene
                nobndbox = int(object_container.findall("nobndbox")[0].text)
                if nobndbox > 0:
                    all_noboxes.append(box_id)

                scene = int(object_container.findall("scene")[0].text)
                if scene > 0:
                    all_scenes.append(box_id)

    anno_info["boxes"] = all_boxes
    anno_info["nobox"] = all_noboxes
    anno_info["scene"] = all_scenes

    return anno_info

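# Illustrative sketch of the Annotations XML layout this parser expects
# (abridged, made-up values):
#
#   <annotation>
#     <size><width>500</width><height>375</height><depth>3</depth></size>
#     <object>
#       <name>1</name>
#       <bndbox><xmin>10</xmin><ymin>20</ymin><xmax>110</xmax><ymax>220</ymax></bndbox>
#     </object>
#   </annotation>
#
# get_annotations would then return
#   {"width": 500, "height": 375, "depth": 3, "boxes": {"1": [[10, 20, 110, 220]]}, "nobox": [], "scene": []}
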
#### END of import from flickr30k_entities

#### Bounding box utilities imported from torchvision and converted to numpy
def box_area(boxes: np.ndarray) -> np.ndarray:
    """
    Computes the area of a set of bounding boxes, which are specified by their
    (x1, y1, x2, y2) coordinates.

    Args:
        boxes (np.ndarray[N, 4]): boxes for which the area will be computed. They
            are expected to be in (x1, y1, x2, y2) format with
            ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Returns:
        area (np.ndarray[N]): area for each box
    """
    assert boxes.ndim == 2 and boxes.shape[-1] == 4
    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

# Implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def _box_inter_union(boxes1: np.ndarray, boxes2: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    # Top-left and bottom-right corners of all pairwise intersections
    lt = np.maximum(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = np.minimum(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clip(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    return inter, union

def box_iou(boxes1: np.ndarray, boxes2: np.ndarray) -> np.ndarray:
    """
    Return intersection-over-union (Jaccard index) of boxes.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1 (np.ndarray[N, 4])
        boxes2 (np.ndarray[M, 4])

    Returns:
        iou (np.ndarray[N, M]): the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
    """
    inter, union = _box_inter_union(boxes1, boxes2)
    iou = inter / union
    return iou

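# A small worked example (illustrative): two 2x2 squares offset by one unit
# overlap in a 1x1 region, so inter = 1, union = 4 + 4 - 1 = 7, IoU = 1/7 ≈ 0.1429.
#
#   >>> box_iou(np.array([[0, 0, 2, 2]]), np.array([[1, 1, 3, 3]]))
#   array([[0.14285714]])
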
#### End of import of box utilities
def _merge_boxes(boxes: List[List[int]]) -> List[List[int]]:
    """
    Return a list containing the smallest enclosing box of all the provided boxes.

    The boxes are expected in [x1, y1, x2, y2] format
    """
    if len(boxes) == 1:
        return boxes

    np_boxes = np.asarray(boxes)

    return [[np_boxes[:, 0].min(), np_boxes[:, 1].min(), np_boxes[:, 2].max(), np_boxes[:, 3].max()]]

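# Illustrative example: merging two overlapping boxes yields their enclosing box,
# e.g. _merge_boxes([[0, 0, 2, 2], [1, 1, 3, 3]]) returns [[0, 0, 3, 3]]
# (as numpy scalars, since the mins/maxes are computed on an np.asarray view).
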
class RecallTracker:
    """Utility class to track recall@k for various k, split by categories"""

    def __init__(self, topk: Sequence[int]):
        """
        Parameters:
           - topk : tuple of ints corresponding to the recalls being tracked (e.g., recall@1, recall@10, ...)
        """
        self.total_byk_bycat: Dict[int, Dict[str, int]] = {k: defaultdict(int) for k in topk}
        self.positives_byk_bycat: Dict[int, Dict[str, int]] = {k: defaultdict(int) for k in topk}

    def add_positive(self, k: int, category: str):
        """Log a positive hit @k for the given category"""
        if k not in self.total_byk_bycat:
            raise RuntimeError(f"{k} is not a valid recall threshold")
        self.total_byk_bycat[k][category] += 1
        self.positives_byk_bycat[k][category] += 1

    def add_negative(self, k: int, category: str):
        """Log a negative hit @k for the given category"""
        if k not in self.total_byk_bycat:
            raise RuntimeError(f"{k} is not a valid recall threshold")
        self.total_byk_bycat[k][category] += 1

    def report(self) -> Dict[int, Dict[str, float]]:
        """Return a condensed report of the results as a dict of dicts.

        report[k][cat] is the recall@k for the given category
        """
        report: Dict[int, Dict[str, float]] = {}
        for k in self.total_byk_bycat:
            assert k in self.positives_byk_bycat
            report[k] = {
                cat: self.positives_byk_bycat[k][cat] / self.total_byk_bycat[k][cat] for cat in self.total_byk_bycat[k]
            }
        return report

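# Minimal usage sketch (illustrative): one hit and one miss at k=1 for the
# "all" category yields a recall@1 of 0.5.
#
#   >>> tracker = RecallTracker(topk=(1,))
#   >>> tracker.add_positive(1, "all")
#   >>> tracker.add_negative(1, "all")
#   >>> tracker.report()
#   {1: {'all': 0.5}}
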
class Flickr30kEntitiesRecallEvaluator:
    def __init__(
        self,
        flickr_path: str,
        subset: str = "test",
        topk: Sequence[int] = (1, 5, 10, -1),
        iou_thresh: float = 0.5,
        merge_boxes: bool = False,
        verbose: bool = True,
    ):
        assert subset in ["train", "test", "val"], f"Wrong flickr subset {subset}"

        self.topk = topk
        self.iou_thresh = iou_thresh

        flickr_path = Path(flickr_path)

        # We load the image ids corresponding to the current subset
        with open(flickr_path / f"{subset}.txt") as file_d:
            self.img_ids = [line.strip() for line in file_d]

        if verbose:
            print(f"Flickr subset contains {len(self.img_ids)} images")

        # Read the box annotations for all the images
        self.imgid2boxes: Dict[str, Dict[str, List[List[int]]]] = {}

        if verbose:
            print("Loading annotations...")

        for img_id in self.img_ids:
            anno_info = get_annotations(flickr_path / "Annotations" / f"{img_id}.xml")["boxes"]
            if merge_boxes:
                merged = {}
                for phrase_id, boxes in anno_info.items():
                    merged[phrase_id] = _merge_boxes(boxes)
                anno_info = merged
            self.imgid2boxes[img_id] = anno_info

        # Read the sentence annotations
        self.imgid2sentences: Dict[str, List[List[Optional[Dict]]]] = {}
        if verbose:
            print("Loading sentences...")

        self.all_ids: List[str] = []
        tot_phrases = 0
        for img_id in self.img_ids:
            sentence_info = get_sentence_data(flickr_path / "Sentences" / f"{img_id}.txt")
            self.imgid2sentences[img_id] = [None for _ in range(len(sentence_info))]

            # Some phrases don't have boxes; we filter them out.
            for sent_id, sentence in enumerate(sentence_info):
                phrases = [phrase for phrase in sentence["phrases"] if phrase["phrase_id"] in self.imgid2boxes[img_id]]
                if len(phrases) > 0:
                    self.imgid2sentences[img_id][sent_id] = phrases
                tot_phrases += len(phrases)

            self.all_ids += [
                f"{img_id}_{k}" for k in range(len(sentence_info)) if self.imgid2sentences[img_id][k] is not None
            ]

        if verbose:
            print(f"There are {tot_phrases} phrases in {len(self.all_ids)} sentences to evaluate")

    def evaluate(self, predictions: List[Dict]):
        evaluated_ids = set()

        recall_tracker = RecallTracker(self.topk)

        for pred in predictions:
            cur_id = f"{pred['image_id']}_{pred['sentence_id']}"
            if cur_id in evaluated_ids:
                print(
                    "Warning, multiple predictions found for sentence "
                    f"{pred['sentence_id']} in image {pred['image_id']}"
                )
                continue

            # Skip the sentences with no valid phrase
            if cur_id not in self.all_ids:
                if len(pred["boxes"]) != 0:
                    print(
                        f"Warning, in image {pred['image_id']} we were not expecting predictions "
                        f"for sentence {pred['sentence_id']}. Ignoring them."
                    )
                continue

            evaluated_ids.add(cur_id)

            pred_boxes = pred["boxes"]
            if str(pred["image_id"]) not in self.imgid2sentences:
                raise RuntimeError(f"Unknown image id {pred['image_id']}")
            if not 0 <= int(pred["sentence_id"]) < len(self.imgid2sentences[str(pred["image_id"])]):
                raise RuntimeError(f"Unknown sentence id {pred['sentence_id']} in image {pred['image_id']}")

            phrases = self.imgid2sentences[str(pred["image_id"])][int(pred["sentence_id"])]
            if len(pred_boxes) != len(phrases):
                raise RuntimeError(
                    f"Error, got {len(pred_boxes)} predictions, expected {len(phrases)} "
                    f"for sentence {pred['sentence_id']} in image {pred['image_id']}"
                )

            for cur_boxes, phrase in zip(pred_boxes, phrases):
                target_boxes = self.imgid2boxes[str(pred["image_id"])][phrase["phrase_id"]]

                ious = box_iou(np.asarray(cur_boxes), np.asarray(target_boxes))
                for k in self.topk:
                    if k == -1:
                        # k == -1 is the upper bound: consider every candidate box
                        maxi = ious.max()
                    else:
                        assert k > 0
                        # Keep only the first k candidates (assumed ranked best-first)
                        maxi = ious[:k].max()
                    if maxi >= self.iou_thresh:
                        recall_tracker.add_positive(k, "all")
                        for phrase_type in phrase["phrase_type"]:
                            recall_tracker.add_positive(k, phrase_type)
                    else:
                        recall_tracker.add_negative(k, "all")
                        for phrase_type in phrase["phrase_type"]:
                            recall_tracker.add_negative(k, phrase_type)

        if len(evaluated_ids) != len(self.all_ids):
            print("ERROR, the number of evaluated sentences doesn't match. Missing predictions:")
            un_processed = set(self.all_ids) - evaluated_ids
            for missing in un_processed:
                img_id, sent_id = missing.split("_")
                print(f"\t sentence {sent_id} in image {img_id}")
            raise RuntimeError("Missing predictions")

        return recall_tracker.report()

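# Each prediction consumed by `evaluate` is expected to look like
# (illustrative, made-up values):
#
#   {"image_id": "36979", "sentence_id": 0,
#    "boxes": [cand_boxes_phrase_0, cand_boxes_phrase_1, ...]}
#
# with one list of candidate [x1, y1, x2, y2] boxes per phrase that has
# ground-truth boxes, best candidate first, since recall@k slices the first k
# candidates of each list.
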
class FlickrEvaluator(object):
    def __init__(
        self,
        flickr_path,
        subset,
        top_k=(1, 5, 10, -1),
        iou_thresh=0.5,
        merge_boxes=False,
    ):
        assert isinstance(top_k, (list, tuple))

        self.evaluator = Flickr30kEntitiesRecallEvaluator(
            flickr_path, subset=subset, topk=top_k, iou_thresh=iou_thresh, merge_boxes=merge_boxes, verbose=False
        )
        self.predictions = []
        self.results = None

    def accumulate(self):
        pass

    def update(self, predictions):
        self.predictions += predictions

    def synchronize_between_processes(self):
        # Gather the predictions accumulated on every process onto each rank
        all_predictions = dist.all_gather(self.predictions)
        self.predictions = sum(all_predictions, [])

    def summarize(self):
        if dist.is_main_process():
            self.results = self.evaluator.evaluate(self.predictions)
            table = PrettyTable()
            all_cat = sorted(list(self.results.values())[0].keys())
            table.field_names = ["Recall@k"] + all_cat

            score = {}
            for k, v in self.results.items():
                cur_results = [v[cat] for cat in all_cat]
                header = "Upper_bound" if k == -1 else f"Recall@{k}"
                for cat in all_cat:
                    score[f"{header}_{cat}"] = v[cat]
                table.add_row([header] + cur_results)
            print(table)
            return score

        # Non-main processes don't compute results
        return None, None
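
# Minimal driving sketch (assumes a flickr30k root directory containing
# test.txt, Annotations/ and Sentences/, plus `predictions` in the format
# described above `evaluate`):
#
#   evaluator = FlickrEvaluator("flickr30k", subset="test", merge_boxes=True)
#   evaluator.update(predictions)
#   evaluator.synchronize_between_processes()
#   score = evaluator.summarize()  # prints a PrettyTable; returns e.g. {"Recall@1_all": ...}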