""" | |
Copyright (c) 2022, salesforce.com, inc. | |
All rights reserved. | |
SPDX-License-Identifier: BSD-3-Clause | |
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
""" | |
import logging
import json
import os

import minigpt4.common.dist_utils as dist_utils
from minigpt4.common.registry import registry
from minigpt4.common.vqa_tools.vqa import VQA
from minigpt4.common.vqa_tools.vqa_eval import VQAEval
from minigpt4.tasks.base_task import BaseTask

class VQATask(BaseTask):
    def __init__(
        self,
        num_beams,
        max_len,
        min_len,
        evaluate,
        num_ans_candidates,
        inference_method="rank",
        prompt="",
    ):
        super().__init__()

        self.num_beams = num_beams
        self.max_len = max_len
        self.min_len = min_len
        self.evaluate = evaluate
        self.inference_method = inference_method
        self.num_ans_candidates = num_ans_candidates
        self.prompt = prompt

        self.answer_list = None

        # Optional COCO-format question/annotation files, keyed by split name;
        # populated in build_datasets() and consumed by _report_metrics().
        self.ques_files = dict()
        self.anno_files = dict()

    @classmethod
    def setup_task(cls, cfg):
        run_cfg = cfg.run_cfg

        num_beams = run_cfg.get("num_beams", 3)
        max_len = run_cfg.get("max_len", 10)
        min_len = run_cfg.get("min_len", 1)

        evaluate = run_cfg.get("evaluate", False)

        inference_method = run_cfg.get("inference_method", "rank")
        num_ans_candidates = run_cfg.get("num_ans_candidates", 128)
        prompt = run_cfg.get("prompt", "")

        return cls(
            num_beams=num_beams,
            max_len=max_len,
            min_len=min_len,
            evaluate=evaluate,
            num_ans_candidates=num_ans_candidates,
            inference_method=inference_method,
            prompt=prompt,
        )
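    # Illustrative `run_cfg` block read by setup_task() above (a sketch: only the
    # keys consumed here are listed, and the values are the defaults from the code):
    #
    #   run:
    #     num_beams: 3
    #     max_len: 10
    #     min_len: 1
    #     evaluate: False
    #     inference_method: "rank"   # "generate" is also handled downstream
    #     num_ans_candidates: 128
    #     prompt: ""
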
    def build_datasets(self, cfg):
        datasets = super().build_datasets(cfg)

        # get question file, annotation file and answer list in COCO format
        for dataset in datasets.values():
            for split in dataset:
                if (
                    hasattr(dataset[split], "coco_fmt_qust_file")
                    and dataset[split].coco_fmt_qust_file is not None
                ):
                    self.ques_files[split] = dataset[split].coco_fmt_qust_file
                    self.anno_files[split] = dataset[split].coco_fmt_anno_file

                try:
                    self.answer_list = dataset[split].answer_list
                except AttributeError:
                    # if answer_list is not provided, then set it to None
                    pass

        if len(self.ques_files) > 0:
            assert len(self.ques_files) == len(
                self.anno_files
            ), "Only support one split for evaluation."

        return datasets

    def valid_step(self, model, samples):
        answers = model.predict_answers(
            samples=samples,
            answer_list=self.answer_list,
            inference_method=self.inference_method,
            num_beams=self.num_beams,
            max_len=self.max_len,
            min_len=self.min_len,
            num_ans_candidates=self.num_ans_candidates,
            prompt=self.prompt,
        )
        pred_qa_pairs = []

        question_id = samples["question_id"]
        for answer, ques_id in zip(answers, question_id):
            ques_id = int(ques_id.item())
            pred_qa_pairs.append({"question_id": ques_id, "answer": answer})

        return pred_qa_pairs
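    # Each entry produced by valid_step() is a plain dict in the format expected by
    # the official VQA result files, e.g. (values are illustrative):
    #   {"question_id": 262148000, "answer": "down"}
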
    def after_evaluation(self, val_result, split_name, result_dir):
        result_file = self.save_result(
            val_result,
            result_dir=result_dir,  # registry.get_path("result_dir")
            filename=split_name,
            remove_duplicate="question_id",
        )

        # metrics = self._report_metrics(result_file=result_file, split=split_name)
        # return metrics

    def _report_metrics(self, result_file, split):
        """
        Use official VQA evaluation script to report metrics.
        """
        metrics = {}

        if split in self.ques_files and split in self.anno_files:
            vqa = VQA(self.anno_files[split], self.ques_files[split])
            vqa_result = vqa.loadRes(
                resFile=result_file, quesFile=self.ques_files[split]
            )

            # create vqaEval object by taking vqa and vqaRes
            # n is precision of accuracy (number of places after decimal), default is 2
            vqa_scorer = VQAEval(vqa, vqa_result, n=2)
            logging.info("Start VQA evaluation.")
            vqa_scorer.evaluate()

            # print accuracies
            overall_acc = vqa_scorer.accuracy["overall"]
            metrics["agg_metrics"] = overall_acc

            logging.info("Overall Accuracy is: %.02f\n" % overall_acc)
            logging.info("Per Answer Type Accuracy is the following:")

            for ans_type in vqa_scorer.accuracy["perAnswerType"]:
                logging.info(
                    "%s : %.02f"
                    % (ans_type, vqa_scorer.accuracy["perAnswerType"][ans_type])
                )
                metrics[ans_type] = vqa_scorer.accuracy["perAnswerType"][ans_type]

            with open(
                os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
            ) as f:
                f.write(json.dumps(metrics) + "\n")

        return metrics
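    # Illustrative shape of the metrics dict written to evaluate.txt above
    # (per-answer-type keys come from the VQA eval tool; the numbers are made up):
    #   {"agg_metrics": 71.3, "yes/no": 85.2, "number": 52.1, "other": 60.4}
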

class GQATask(VQATask):
    def valid_step(self, model, samples):
        answers = model.predict_answers(
            samples=samples,
            answer_list=self.answer_list,
            inference_method=self.inference_method,
            num_beams=self.num_beams,
            max_len=self.max_len,
            min_len=self.min_len,
            num_ans_candidates=self.num_ans_candidates,
            prompt=self.prompt,
        )
        pred_qa_pairs = []

        question_id = samples["question_id"]
        gt_answers = samples["answer"]

        for answer, ques_id, gt_answer in zip(answers, question_id, gt_answers):
            ques_id = int(ques_id.item())
            pred_qa_pairs.append(
                {"question_id": ques_id, "pred_ans": answer, "gt_ans": gt_answer}
            )

        return pred_qa_pairs
    def _report_metrics(self, result_file, split):
        """
        TODO: add other evaluation metrics for GQA
        """
        results = json.load(open(result_file, "r"))
        acc = []
        vqa_tool = VQAEval()

        for res in results:
            if res["gt_ans"] is None:
                # prepare test results for leaderboard evaluation
                # (assumes a _save_result_leaderboard implementation is available,
                # e.g. the one defined on AOKVQATask below)
                self._save_result_leaderboard(results)
                return

            gt_ans = res["gt_ans"]
            pred = res["pred_ans"]

            if self.inference_method == "generate":
                pred = vqa_tool.processPunctuation(pred)
                pred = vqa_tool.processDigitArticle(pred)

            vqa_acc = 1 if pred == gt_ans else 0

            acc.append(vqa_acc)

        accuracy = sum(acc) / len(acc) * 100
        metrics = {"agg_metrics": accuracy, "acc": accuracy}

        with open(
            os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
        ) as f:
            f.write(json.dumps(metrics) + "\n")

        logging.info(metrics)

        return metrics

class ScienceQATask(GQATask):
    def valid_step(self, model, samples):
        answers = model.predict_class(
            samples=samples,
            answer_list=self.answer_list,
            inference_method=self.inference_method,
            num_beams=self.num_beams,
            max_len=self.max_len,
            min_len=self.min_len,
            num_ans_candidates=self.num_ans_candidates,
            prompt=self.prompt,
        )
        pred_qa_pairs = []

        question_id = samples["question_id"]
        gt_answers = samples["answer"]

        for answer, ques_id, gt_answer in zip(answers, question_id, gt_answers):
            ques_id = int(ques_id.item())
            pred_qa_pairs.append(
                {"question_id": ques_id, "pred_ans": answer, "gt_ans": gt_answer}
            )

        return pred_qa_pairs

class AOKVQATask(VQATask):
    def valid_step(self, model, samples):
        answers = model.predict_answers(
            samples=samples,
            answer_list=self.answer_list,
            inference_method=self.inference_method,
            num_beams=self.num_beams,
            max_len=self.max_len,
            min_len=self.min_len,
            num_ans_candidates=self.num_ans_candidates,
        )

        pred_qa_pairs = []

        question_id = samples["question_id"]
        gt_answers = samples["direct_answers"]

        for pred_answer, ques_id, gt_answer in zip(answers, question_id, gt_answers):
            # question ids are used as provided (unlike the tasks above, no .item() cast)
            pred_qa_pairs.append(
                {"question_id": ques_id, "pred_ans": pred_answer, "gt_ans": gt_answer}
            )

        return pred_qa_pairs
    def _report_metrics(self, result_file, split):
        """
        Implementing accuracy computation for AOKVQA, see
        https://github.com/allenai/aokvqa/blob/main/evaluation/eval_predictions.py#L45 for details.
        """
        # TODO add evaluation for multi-choice

        results = json.load(open(result_file, "r"))
        acc = []

        for res in results:
            if res["gt_ans"] is None:
                # prepare test results for leaderboard evaluation
                self._save_result_leaderboard(results)
                return

            pred = res["pred_ans"]
            gt_ans = res["gt_ans"]

            # VQA-style soft accuracy: full credit if the prediction matches at
            # least 3 of the annotated direct answers, partial credit otherwise.
            num_match = sum([pred == gt for gt in gt_ans])
            vqa_acc = min(1.0, num_match / 3.0)

            acc.append(vqa_acc)

        accuracy = sum(acc) / len(acc) * 100
        metrics = {"agg_metrics": accuracy, "acc": accuracy}

        with open(
            os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
        ) as f:
            f.write(json.dumps(metrics) + "\n")

        logging.info(metrics)

        return metrics
    def _save_result_leaderboard(self, results):
        """
        Saving the results in the format required for leaderboard evaluation.

        [TODO] add support for multi-choice.
        """
        result_leaderboard = dict()
        for res in results:
            result_leaderboard[res["question_id"]] = {
                "direct_answer": res["pred_ans"],
                "multiple_choice": "",
            }

        result_file = registry.get_path("result_dir") + "_leaderboard.json"

        with open(result_file, "w") as f:
            json.dump(result_leaderboard, f)

        logging.info(f"Saved results for leaderboard evaluation at {result_file}")