|
import argparse |
|
import json |
|
import os |
|
import re |
|
import random |
|
from collections import defaultdict |
|
|
|
|
|
def get_args(): |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument('--base-dir', type=str) |
|
parser.add_argument('--gpt4-result', type=str) |
|
parser.add_argument('--requery-result', type=str) |
|
parser.add_argument('--our-result', type=str) |
|
parser.add_argument('--output-result', type=str) |
|
parser.add_argument('--split', type=str, default='test') |
|
parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) |
|
return parser.parse_args() |
|
|
|
|
|
def convert_caps(results): |
|
fakecaps = [] |
|
for result in results: |
|
image_id = result['question_id'] |
|
caption = result['text'] |
|
fakecaps.append({"image_id": int(image_id), "caption": caption}) |
|
return fakecaps |
|
|
|
|
|
def get_pred_idx(prediction, choices, options): |
|
""" |
|
Get the index (e.g. 2) from the prediction (e.g. 'C') |
|
""" |
|
if prediction in options[:len(choices)]: |
|
return options.index(prediction) |
|
else: |
|
return random.choice(range(len(choices))) |
|
|
|
|
|
if __name__ == "__main__": |
|
args = get_args() |
|
|
|
base_dir = args.base_dir |
|
split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] |
|
problems = json.load(open(os.path.join(base_dir, "problems.json"))) |
|
our_predictions = [json.loads(line) for line in open(args.our_result)] |
|
our_predictions = {pred['question_id']: pred for pred in our_predictions} |
|
split_problems = {idx: problems[idx] for idx in split_indices} |
|
|
|
requery_predictions = [json.loads(line) for line in open(args.requery_result)] |
|
requery_predictions = {pred['question_id']: pred for pred in requery_predictions} |
|
|
|
gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] |
|
|
|
results = defaultdict(lambda: 0) |
|
|
|
sqa_results = {} |
|
sqa_results['acc'] = None |
|
sqa_results['correct'] = None |
|
sqa_results['count'] = None |
|
sqa_results['results'] = {} |
|
sqa_results['outputs'] = {} |
|
|
|
for prob_id, prob in split_problems.items(): |
|
if prob_id not in our_predictions: |
|
assert False |
|
if prob_id not in gpt4_predictions: |
|
assert False |
|
our_pred = our_predictions[prob_id]['text'] |
|
gpt4_pred = gpt4_predictions[prob_id] |
|
if prob_id not in requery_predictions: |
|
results['missing_requery'] += 1 |
|
requery_pred = "MISSING" |
|
else: |
|
requery_pred = requery_predictions[prob_id]['text'] |
|
|
|
pattern = re.compile(r'The answer is ([A-Z]).') |
|
our_res = pattern.findall(our_pred) |
|
if len(our_res) == 1: |
|
our_answer = our_res[0] |
|
else: |
|
our_answer = "FAILED" |
|
|
|
requery_res = pattern.findall(requery_pred) |
|
if len(requery_res) == 1: |
|
requery_answer = requery_res[0] |
|
else: |
|
requery_answer = "FAILED" |
|
|
|
gpt4_res = pattern.findall(gpt4_pred) |
|
if len(gpt4_res) == 1: |
|
gpt4_answer = gpt4_res[0] |
|
else: |
|
gpt4_answer = "FAILED" |
|
|
|
our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) |
|
gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) |
|
requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options) |
|
|
|
results['total'] += 1 |
|
|
|
if gpt4_answer == 'FAILED': |
|
results['gpt4_failed'] += 1 |
|
if gpt4_pred_idx == prob['answer']: |
|
results['gpt4_correct'] += 1 |
|
if our_pred_idx == prob['answer']: |
|
results['gpt4_ourvisual_correct'] += 1 |
|
elif gpt4_pred_idx == prob['answer']: |
|
results['gpt4_correct'] += 1 |
|
results['gpt4_ourvisual_correct'] += 1 |
|
|
|
if our_pred_idx == prob['answer']: |
|
results['our_correct'] += 1 |
|
|
|
if requery_answer == 'FAILED': |
|
sqa_results['results'][prob_id] = our_pred_idx |
|
if our_pred_idx == prob['answer']: |
|
results['requery_correct'] += 1 |
|
else: |
|
sqa_results['results'][prob_id] = requery_pred_idx |
|
if requery_pred_idx == prob['answer']: |
|
results['requery_correct'] += 1 |
|
else: |
|
print(f""" |
|
Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']} |
|
Our ({our_answer}): {our_pred} |
|
GPT-4 ({gpt4_answer}): {gpt4_pred} |
|
Requery ({requery_answer}): {requery_pred} |
|
print("=====================================") |
|
""") |
|
|
|
if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: |
|
results['correct_upperbound'] += 1 |
|
|
|
total = results['total'] |
|
print(f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%') |
|
print(f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%') |
|
print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') |
|
print(f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%') |
|
print(f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%') |
|
print(f'Total: {total}, Correct upper: {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') |
|
|
|
sqa_results['acc'] = results["requery_correct"] / total * 100 |
|
sqa_results['correct'] = results["requery_correct"] |
|
sqa_results['count'] = total |
|
|
|
with open(args.output_result, 'w') as f: |
|
json.dump(sqa_results, f, indent=2) |
|
|
|
|