import os
import json
import argparse
from collections import defaultdict

import pandas as pd

from tinychart.eval.eval_chartqa_metric import chartqa_evaluator, chartqapot_evaluator
from tinychart.eval.eval_chartqa_metric import chartqa_oracle_merger_evaluator, chartqa_rule_merger_evaluator


def read_jsonl(jsonl_path):
    with open(jsonl_path, 'r') as f:
        data = [json.loads(line) for line in f]
    return data


def write_jsonl(data, jsonl_path):
    with open(jsonl_path, 'w', encoding='utf-8') as f:
        for item in data:
            f.write(json.dumps(item) + '\n')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', default='./output/')
    args = parser.parse_args()

    # collect all per-dataset prediction files in the input directory
    result_files = os.listdir(args.input)
    result_files = [f for f in result_files if f.endswith('.jsonl')]
    result_files.sort()

    # bug fix: the merge step below checks pot_result, so initialize it here
    # (the original initialized cot_result, which would raise a NameError
    # whenever no PoT result file was present)
    direct_result, pot_result = None, None
    dataset2metric = defaultdict(float)
    for result_file in result_files:
        dataset_name = '.'.join(result_file.split('.')[:-1])
        file = os.path.join(args.input, result_file)
        result_data = read_jsonl(file)
        if 'chartqa-' in dataset_name:
            # score direct-answer predictions and write the per-example
            # judgements back into the same file
            direct_result, direct_acc = chartqa_evaluator(result_data, key='model_answer')
            write_jsonl(direct_result, file)
            dataset2metric[dataset_name] = round(direct_acc * 100, 2)
            print(f'Direct Accuracy: {direct_acc}')
        elif 'chartqagptpot-' in dataset_name or 'chartqatemplatepot-' in dataset_name:
            # score Program-of-Thought predictions; the evaluator also reports
            # the fraction of generated programs that failed to execute
            pot_result, pot_acc, error_rate = chartqapot_evaluator(result_data)
            write_jsonl(pot_result, file)
            dataset2metric[dataset_name] = round(pot_acc * 100, 2)
            print(f'PoT Accuracy: {pot_acc}')
            print(f'PoT Error Rate: {error_rate}')

    # if both direct and PoT results exist, evaluate the two merging strategies
    if direct_result is not None and pot_result is not None:
        print('Merging direct and PoT results')
        # oracle merger: takes whichever answer is correct (an upper bound)
        oracle_results, oracle_acc = chartqa_oracle_merger_evaluator(direct_result, pot_result)
        dataset2metric['merged-oracle'] = round(oracle_acc * 100, 2)
        print(f'Oracle Merged Accuracy: {oracle_acc}')
        write_jsonl(oracle_results, os.path.join(args.input, 'merged-oracle.jsonl'))
        # rule merger: picks between the two answers with a heuristic rule
        rule_results, rule_acc = chartqa_rule_merger_evaluator(direct_result, pot_result)
        dataset2metric['merged-rule'] = round(rule_acc * 100, 2)
        print(f'Rule Merged Accuracy: {rule_acc}')
        write_jsonl(rule_results, os.path.join(args.input, 'merged-rule.jsonl'))

    # save metrics as a one-row TSV with the dataset names as the header
    df = pd.DataFrame(dataset2metric, index=[0])
    # avoid overwriting an existing metrics.tsv: fall back to
    # metrics.1.tsv, metrics.2.tsv, ... until a free name is found
    tsv_name = os.path.join(args.input, 'metrics.tsv')
    i = 1
    while os.path.exists(tsv_name):
        tsv_name = os.path.join(args.input, f'metrics.{i}.tsv')
        i += 1
    df.to_csv(tsv_name, sep='\t', index=False)
    print(f'Metrics saved at: {tsv_name}')
    print(df)
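
# Usage sketch (hypothetical script and file names; the script name and the
# example dataset files below are assumptions, not taken from the source):
#
#   python eval_chartqa.py --input ./output/
#
# This assumes an earlier inference step has written per-dataset prediction
# files such as chartqa-test.jsonl and chartqatemplatepot-test.jsonl into
# ./output/, where each line is a JSON record that includes at least a
# 'model_answer' field for the direct evaluator (the exact record schema is
# defined by tinychart.eval.eval_chartqa_metric, not by this script).
# The run prints direct / PoT / merged accuracies and saves them to a TSV,
# e.g. ./output/metrics.tsv.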