# File size: 3,243 Bytes
# d6bc023
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import json
import argparse
import pandas as pd
from collections import defaultdict
from tinychart.eval.eval_chartqa_metric import chartqa_evaluator, chartqapot_evaluator
from tinychart.eval.eval_chartqa_metric import chartqa_oracle_merger_evaluator, chartqa_rule_merger_evaluator

def read_jsonl(jsonl_path):
    """Load a JSON Lines file into a list of decoded records.

    Args:
        jsonl_path: Path of the ``.jsonl`` file to read.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Decode as UTF-8 explicitly so reading matches write_jsonl and does not
    # depend on the platform's default locale encoding.
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        # Blank lines (e.g. a trailing empty line) carry no record; skip them
        # instead of letting json.loads fail on an empty document.
        return [json.loads(line) for line in f if line.strip()]

def write_jsonl(data, jsonl_path):
    """Serialize each element of ``data`` as one JSON document per line.

    The target file is opened in write mode as UTF-8, so any existing
    content at ``jsonl_path`` is replaced.
    """
    with open(jsonl_path, 'w', encoding='utf-8') as sink:
        sink.writelines(json.dumps(record) + '\n' for record in data)

if __name__ == '__main__':
    # Script entry point: score every recognised result file found in
    # --input, write the scored records back in place, optionally merge the
    # direct and PoT results, and save a one-row metrics table as TSV.
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', default='./output/')

    args = parser.parse_args()

    # Sorted so files are processed (and metric columns appear) in a stable order.
    result_files = sorted(f for f in os.listdir(args.input) if f.endswith('.jsonl'))

    # Both must start as None: the merge step below tests them, and a missing
    # initial value for pot_result raised NameError whenever the directory
    # contained no PoT result file.
    direct_result, pot_result = None, None

    dataset2metric = defaultdict(float)
    for result_file in result_files:
        # Dataset name is the file name without its final extension.
        dataset_name = '.'.join(result_file.split('.')[:-1])
        is_direct = 'chartqa-' in dataset_name
        is_pot = 'chartqagptpot-' in dataset_name or 'chartqatemplatepot-' in dataset_name
        if not (is_direct or is_pot):
            # Unrecognised file (e.g. merged-*.jsonl written by an earlier
            # run of this script): nothing to evaluate, so don't parse it.
            continue

        file = os.path.join(args.input, result_file)
        result_data = read_jsonl(file)
        if is_direct:
            direct_result, direct_acc = chartqa_evaluator(result_data, key='model_answer')
            # The evaluated records overwrite the input file.
            write_jsonl(direct_result, file)
            dataset2metric[dataset_name] = round(direct_acc * 100, 2)
            print(f'Direct Accuracy: {direct_acc}')
        else:
            pot_result, pot_acc, error_rate = chartqapot_evaluator(result_data)
            write_jsonl(pot_result, file)
            dataset2metric[dataset_name] = round(pot_acc * 100, 2)
            print(f'PoT Accuracy: {pot_acc}')
            print(f'PoT Error Rate: {error_rate}')

    # Merging needs one direct and one PoT result; if several files of a kind
    # were processed, the last one in sorted order is the one used.
    if direct_result is not None and pot_result is not None:
        print("Calculate merging direct and pot results with simple divider")
        oracle_results, oracle_acc = chartqa_oracle_merger_evaluator(direct_result, pot_result)
        dataset2metric['merged-oracle'] = round(oracle_acc * 100, 2)
        print(f'Oracle Merged Accuracy: {oracle_acc}')
        write_jsonl(oracle_results, os.path.join(args.input, 'merged-oracle.jsonl'))
        rule_results, rule_acc = chartqa_rule_merger_evaluator(direct_result, pot_result)
        dataset2metric['merged-rule'] = round(rule_acc * 100, 2)
        print(f'Rule Merged Accuracy: {rule_acc}')
        write_jsonl(rule_results, os.path.join(args.input, 'merged-rule.jsonl'))

    # Save metrics as a single-row TSV with the dataset names as the header.
    df = pd.DataFrame(dataset2metric, index=[0])
    # Never overwrite an earlier report: use metrics.tsv if it is free,
    # otherwise the first unused metrics.1.tsv, metrics.2.tsv, ...
    tsv_name = os.path.join(args.input, 'metrics.tsv')
    i = 0
    while os.path.exists(tsv_name):
        i += 1
        tsv_name = os.path.join(args.input, f'metrics.{i}.tsv')
    df.to_csv(tsv_name, sep='\t', index=False)
    print(f'Metrics saved at: {tsv_name}')
    print(df)