# TinyChart-3B / scripts / calculate_metric.py
# xzl12306 - first commit - d6bc023
import os
import json
import argparse
import pandas as pd
from collections import defaultdict
from tinychart.eval.eval_chartqa_metric import chartqa_evaluator, chartqapot_evaluator
from tinychart.eval.eval_chartqa_metric import chartqa_oracle_merger_evaluator, chartqa_rule_merger_evaluator
def read_jsonl(jsonl_path):
    """Load a JSON Lines file into a list of decoded objects.

    Args:
        jsonl_path: Path of the JSON Lines file to read.

    Returns:
        A list holding one decoded JSON value per non-blank line, in file
        order. An empty file yields an empty list.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Read as UTF-8 explicitly (matching what write_jsonl produces) so the
    # result does not depend on the platform's default text encoding.
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        # Blank lines (e.g. a trailing empty line at EOF) carry no record and
        # would make json.loads raise, so they are skipped.
        return [json.loads(line) for line in f if line.strip()]
def write_jsonl(data, jsonl_path):
    """Write every element of *data* to *jsonl_path* as one JSON line.

    Args:
        data: Iterable of JSON-serializable records.
        jsonl_path: Destination path; an existing file is overwritten.
    """
    with open(jsonl_path, 'w', encoding='utf-8') as out:
        # One serialized record per line, each terminated by a newline.
        out.writelines(json.dumps(record) + '\n' for record in data)
if __name__ == '__main__':
    # Script entry point: evaluate every ``*.jsonl`` result file found in the
    # --input directory, print the accuracies, and save them as a one-row TSV.
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', default='./output/')
    args = parser.parse_args()

    # Sorted so files are processed in a deterministic order.
    result_files = sorted(
        f for f in os.listdir(args.input) if f.endswith('.jsonl')
    )

    # Both must start as None: the merge step below tests them, and a run
    # that contains only one kind of result file must not hit an unbound
    # name when it gets there.
    direct_result, pot_result = None, None

    dataset2metric = defaultdict(float)
    for result_file in result_files:
        # File name without its final extension, used as the metric key.
        dataset_name = '.'.join(result_file.split('.')[:-1])
        file = os.path.join(args.input, result_file)
        result_data = read_jsonl(file)
        if 'chartqa-' in dataset_name:
            direct_result, direct_acc = chartqa_evaluator(result_data, key='model_answer')
            # The records returned by the evaluator replace the input file.
            write_jsonl(direct_result, file)
            dataset2metric[dataset_name] = round(direct_acc * 100, 2)
            print(f'Direct Accuracy: {direct_acc}')
        elif 'chartqagptpot-' in dataset_name or 'chartqatemplatepot-' in dataset_name:
            pot_result, pot_acc, error_rate = chartqapot_evaluator(result_data)
            # The records returned by the evaluator replace the input file.
            write_jsonl(pot_result, file)
            dataset2metric[dataset_name] = round(pot_acc * 100, 2)
            print(f'PoT Accuracy: {pot_acc}')
            print(f'PoT Error Rate: {error_rate}')

    # Merged metrics need both a direct-answer result and a PoT result. If
    # several files matched one branch above, the last one processed is used.
    if direct_result is not None and pot_result is not None:
        print("Calculate merging direct and pot results with simple divider")
        oracle_results, oracle_acc = chartqa_oracle_merger_evaluator(direct_result, pot_result)
        dataset2metric['merged-oracle'] = round(oracle_acc * 100, 2)
        print(f'Oracle Merged Accuracy: {oracle_acc}')
        write_jsonl(oracle_results, os.path.join(args.input, 'merged-oracle.jsonl'))

        rule_results, rule_acc = chartqa_rule_merger_evaluator(direct_result, pot_result)
        dataset2metric['merged-rule'] = round(rule_acc * 100, 2)
        print(f'Rule Merged Accuracy: {rule_acc}')
        write_jsonl(rule_results, os.path.join(args.input, 'merged-rule.jsonl'))

    # Save the metrics as a TSV: dataset names in the header row, values in
    # the single data row.
    df = pd.DataFrame(dataset2metric, index=[0])

    # Never overwrite an earlier report: if metrics.tsv exists, fall back to
    # metrics.1.tsv, metrics.2.tsv, ... and take the first unused name.
    tsv_name = os.path.join(args.input, 'metrics.tsv')
    if os.path.exists(tsv_name):
        i = 1
        tsv_name = os.path.join(args.input, f'metrics.{i}.tsv')
        while os.path.exists(tsv_name):
            i += 1
            tsv_name = os.path.join(args.input, f'metrics.{i}.tsv')

    df.to_csv(tsv_name, sep='\t', index=False)
    print(f'Metrics saved at: {tsv_name}')
    print(df)