# Spaces: Runtime error (page status banner captured together with this file; not part of the code)
from evaluate import Evaluator, ALL_TASKS
from baselines import *
from alignscore.inference import Inferencer
# Standard-library imports stay after the star import so they take precedence
# over any same-named exports from `baselines`.
import json
import logging
import os
import time
from argparse import ArgumentParser
# Module-wide switches consulted by the evaluation helpers in this file.
# Forwarded as ``save_all_tables`` to the Evaluator built by the eval_* helpers.
SAVE_ALL_TABLES: bool = True
# Read by Timer.finish at call time; the ``--timer`` CLI flag sets it to True.
SAVE_AND_PRINT_TIMER: bool = False
class Timer:
    """Measure the elapsed time of one evaluation run and optionally record it.

    The clock starts when the instance is created. ``finish`` does nothing unless
    the module-level ``SAVE_AND_PRINT_TIMER`` flag is true (the ``--timer`` CLI
    option sets it). When the flag is true, ``finish`` prints the duration and
    appends one JSON object per line, ``{display_name: seconds}``, to ``save_path``.
    """

    def __init__(self, save_path: str = 'exp_results/time.json') -> None:
        # perf_counter is monotonic and high-resolution, so the measured duration
        # is not skewed by system clock adjustments the way time.time() can be.
        self.t0 = time.perf_counter()
        self.save_path = save_path

    def finish(self, display_name: str) -> None:
        """Report the seconds elapsed since construction under ``display_name``."""
        time_pass = time.perf_counter() - self.t0
        if not SAVE_AND_PRINT_TIMER:
            return
        print(f"Evaluator {display_name} finished in {time_pass} secs.")
        # Create the parent directory so the timer also works outside run_benchmarks,
        # which is the only place that pre-creates exp_results/.
        save_dir = os.path.dirname(self.save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        with open(self.save_path, 'a', encoding='utf8') as f:
            json.dump({display_name: time_pass}, f)
            f.write('\n')
def eval_ctc(model_type, tasks=ALL_TASKS):
    """Benchmark the CTC baseline for ``model_type``.

    ``CTCScorer(model_type).score`` is the alignment function given to the
    Evaluator; the run is named ``CTC-<model_type>``.
    """
    run_name = f"CTC-{model_type}"
    scorer = CTCScorer(model_type)
    evaluator = Evaluator(
        eval_tasks=tasks,
        align_func=scorer.score,
        save_all_tables=SAVE_ALL_TABLES,
    )
    evaluator.result_save_name = f"baselines/{run_name}"
    stopwatch = Timer()
    evaluator.evaluate()
    stopwatch.finish(run_name)
def eval_simcse(model_type, device, tasks=ALL_TASKS):
    """Benchmark a SimCSE checkpoint on ``device``.

    The run is named after the last '/'-separated segment of ``model_type``
    followed by ``_f``.
    """
    short_name = model_type.rsplit('/', 1)[-1]
    run_name = f"{short_name}_f"
    scorer = SimCSEScorer(model_type, device)
    evaluator = Evaluator(
        eval_tasks=tasks,
        align_func=scorer.score,
        save_all_tables=SAVE_ALL_TABLES,
    )
    evaluator.result_save_name = f"baselines/{run_name}"
    stopwatch = Timer()
    evaluator.evaluate()
    stopwatch.finish(run_name)
def eval_bleurt(checkpoint, tasks=ALL_TASKS):
    """Benchmark the BLEURT baseline loaded from ``checkpoint``; the run is named ``BLEURT``."""
    run_name = "BLEURT"
    align_func = BleurtScorer(checkpoint).scorer
    evaluator = Evaluator(eval_tasks=tasks, align_func=align_func, save_all_tables=SAVE_ALL_TABLES)
    evaluator.result_save_name = "baselines/" + run_name
    stopwatch = Timer()
    evaluator.evaluate()
    stopwatch.finish(run_name)
def eval_bertscore(model_type, device, batch_size, tasks=ALL_TASKS):
    """Benchmark BERTScore (F1) with ``model_type``.

    '/' in the model id is replaced by '-' so the run name is a single path segment.
    """
    model_slug = model_type.replace('/', '-')
    run_name = f"bertscore_{model_slug}_f"
    scorer = BertScoreScorer(model_type=model_type, metric='f1', device=device, batch_size=batch_size)
    evaluator = Evaluator(
        eval_tasks=tasks,
        align_func=scorer.scorer,
        save_all_tables=SAVE_ALL_TABLES,
    )
    evaluator.result_save_name = f"baselines/{run_name}"
    stopwatch = Timer()
    evaluator.evaluate()
    stopwatch.finish(run_name)
def eval_bartscore(checkpoint, device, tasks=ALL_TASKS):
    """Benchmark BARTScore with ``checkpoint``; '/' in the id becomes '-' in the run name."""
    run_name = "bartscore-" + checkpoint.replace('/', '-')
    scorer = BartScoreScorer(checkpoint, device)
    evaluator = Evaluator(
        eval_tasks=tasks,
        align_func=scorer.scorer,
        save_all_tables=SAVE_ALL_TABLES,
    )
    evaluator.result_save_name = f"baselines/{run_name}"
    stopwatch = Timer()
    evaluator.evaluate()
    stopwatch.finish(run_name)
### Below are the baselines for SummaC
def eval_mnli(model="roberta-large-mnli", device='cuda:0', tasks=ALL_TASKS):
    """Benchmark an MNLI classifier baseline via ``MNLIScorer``.

    Args:
        model: checkpoint name for ``MNLIScorer``. ``--mnli-ckpt`` accepts arbitrary
            strings, so this may be an 'org/name' style id.
        device: device string forwarded to ``MNLIScorer``.
        tasks: evaluation task names forwarded to ``Evaluator``.
    """
    mnli_scorer = MNLIScorer(model=model, device=device)
    evaluator = Evaluator(eval_tasks=tasks, align_func=mnli_scorer.scorer, save_all_tables=SAVE_ALL_TABLES)
    # Flatten '/' the same way the BERTScore/BARTScore runners do, so an 'org/name'
    # checkpoint id does not become a nested sub-directory in the result name.
    run_name = f"mnli-{model.replace('/', '-')}"
    evaluator.result_save_name = f"baselines/{run_name}"
    timer = Timer()
    evaluator.evaluate()
    timer.finish(run_name)
def eval_ner(tasks=ALL_TASKS):
    """Benchmark the NER-overlap baseline; the run is named ``NER``."""
    run_name = "NER"
    align_func = NERScorer().scorer
    evaluator = Evaluator(eval_tasks=tasks, align_func=align_func, save_all_tables=SAVE_ALL_TABLES)
    evaluator.result_save_name = "baselines/" + run_name
    stopwatch = Timer()
    evaluator.evaluate()
    stopwatch.finish(run_name)
def eval_unieval(tasks=ALL_TASKS, device='cuda:0'):
    """Benchmark UniEval configured with ``task='fact'``; the run is named ``UniEval``."""
    run_name = "UniEval"
    align_func = UniEvalScorer(task='fact', device=device).scorer
    evaluator = Evaluator(eval_tasks=tasks, align_func=align_func, save_all_tables=SAVE_ALL_TABLES)
    evaluator.result_save_name = "baselines/" + run_name
    stopwatch = Timer()
    evaluator.evaluate()
    stopwatch.finish(run_name)
def eval_feqa(tasks=ALL_TASKS):
    """Benchmark the FEQA baseline; the run is named ``FEQA``."""
    run_name = "FEQA"
    align_func = FEQAScorer().scorer
    evaluator = Evaluator(eval_tasks=tasks, align_func=align_func, save_all_tables=SAVE_ALL_TABLES)
    evaluator.result_save_name = "baselines/" + run_name
    stopwatch = Timer()
    evaluator.evaluate()
    stopwatch.finish(run_name)
def eval_questeval(tasks=ALL_TASKS):
    """Benchmark the QuestEval baseline; the run is named ``QuestEval``."""
    run_name = "QuestEval"
    align_func = QuestEvalScorer().scorer
    evaluator = Evaluator(eval_tasks=tasks, align_func=align_func, save_all_tables=SAVE_ALL_TABLES)
    evaluator.result_save_name = "baselines/" + run_name
    stopwatch = Timer()
    evaluator.evaluate()
    stopwatch.finish(run_name)
def eval_qafacteval(tasks=ALL_TASKS, device='cuda:0'):
    """Benchmark the QAFactEval baseline.

    Model files are loaded from ``../BaselineForNLGEval/QAFactEval/models``,
    resolved against the current working directory.
    NOTE(review): the CLI help says to clone QAFactEval into baselines/QAFactEval,
    which differs from this path — confirm which location is intended.

    Args:
        tasks: evaluation task names forwarded to ``Evaluator``.
        device: device forwarded to ``QAFactEvalScorer``.
    """
    # QAFactEval needs its own dependency set; remind the user which env to use.
    logging.warning("using conda env qaeval!!!")
    qafacteval = QAFactEvalScorer(device=device, model_folder=os.path.abspath('../BaselineForNLGEval/QAFactEval/models'))
    evaluator = Evaluator(eval_tasks=tasks, align_func=qafacteval.scorer, save_all_tables=SAVE_ALL_TABLES)
    evaluator.result_save_name = "baselines/QAFactEval"
    # Time the run like every other baseline runner in this file.
    timer = Timer()
    evaluator.evaluate()
    timer.finish("QAFactEval")
def eval_dae(tasks=ALL_TASKS, model_dir=None, device=0):
    """Benchmark the DAE baseline loaded from ``model_dir``; the run is named ``DAE``."""
    run_name = "DAE"
    align_func = DAEScorer(model_dir=model_dir, device=device).scorer
    evaluator = Evaluator(eval_tasks=tasks, align_func=align_func, save_all_tables=SAVE_ALL_TABLES)
    evaluator.result_save_name = "baselines/" + run_name
    stopwatch = Timer()
    evaluator.evaluate()
    stopwatch.finish(run_name)
def eval_bleu(tasks=ALL_TASKS, n_grams=1):
    """Benchmark BLEU with ``n_grams``; the run is named ``BLEU-<n_grams>``."""
    run_name = f"BLEU-{n_grams}"
    align_func = BLEUScorer(n_grams=n_grams).scorer
    evaluator = Evaluator(eval_tasks=tasks, align_func=align_func, save_all_tables=SAVE_ALL_TABLES)
    evaluator.result_save_name = f"baselines/{run_name}"
    stopwatch = Timer()
    evaluator.evaluate()
    stopwatch.finish(run_name)
def eval_rouge(tasks=ALL_TASKS, rouge_type='1'):
    """Benchmark ROUGE of the given ``rouge_type``; the run is named ``ROUGE-<rouge_type>``."""
    run_name = f"ROUGE-{rouge_type}"
    align_func = ROUGEScorer(rouge_type=rouge_type).scorer
    evaluator = Evaluator(eval_tasks=tasks, align_func=align_func, save_all_tables=SAVE_ALL_TABLES)
    evaluator.result_save_name = f"baselines/{run_name}"
    stopwatch = Timer()
    evaluator.evaluate()
    stopwatch.finish(run_name)
def eval_factcc(script_path, test_data_path, result_path, tasks=ALL_TASKS):
    """Benchmark FactCC, driven by the given script, test-data and result paths.

    The run is named ``FactCC``.
    """
    run_name = "FactCC"
    factcc = FactCCScorer(
        script_path=script_path,
        test_data_path=test_data_path,
        result_path=result_path,
    )
    evaluator = Evaluator(eval_tasks=tasks, align_func=factcc.scorer, save_all_tables=SAVE_ALL_TABLES)
    evaluator.result_save_name = "baselines/" + run_name
    stopwatch = Timer()
    evaluator.evaluate()
    stopwatch.finish(run_name)
def eval_blanc(tasks=ALL_TASKS, device='cuda:0', batch_size=64):
    """Benchmark BLANC on ``device`` with ``batch_size``; the run is named ``BLANC``."""
    run_name = "BLANC"
    align_func = BLANCScorer(device=device, batch_size=batch_size).scorer
    evaluator = Evaluator(eval_tasks=tasks, align_func=align_func, save_all_tables=SAVE_ALL_TABLES)
    evaluator.result_save_name = "baselines/" + run_name
    stopwatch = Timer()
    evaluator.evaluate()
    stopwatch.finish(run_name)
def eval_summac(tasks=ALL_TASKS, summac_type='conv', device='cuda:0'):
    """Benchmark SummaC of the given variant; the run is named ``SummaC-<summac_type>``."""
    run_name = f"SummaC-{summac_type}"
    align_func = SummaCScorer(summac_type=summac_type, device=device).scorer
    evaluator = Evaluator(eval_tasks=tasks, align_func=align_func, save_all_tables=SAVE_ALL_TABLES)
    evaluator.result_save_name = f"baselines/{run_name}"
    stopwatch = Timer()
    evaluator.evaluate()
    stopwatch.finish(run_name)
def eval_align_nlg(ckpt_path, comment='', base_model='roberta-large', batch_size=32, device='cuda:0', tasks=ALL_TASKS, nlg_eval_mode='nli_sp'):
    """Benchmark AlignScore using the checkpoint at ``ckpt_path``.

    Modes containing 'smart' are assigned to the inferencer's ``smart_type``
    attribute; every other mode is assigned to ``nlg_eval_mode``. The run is named
    ``AlignScore-<mode>-<base_model>``, with ``_<comment>`` appended when
    ``comment`` is non-empty.
    """
    inferencer = Inferencer(ckpt_path=ckpt_path, model=base_model, batch_size=batch_size, device=device)
    mode_attribute = 'smart_type' if 'smart' in nlg_eval_mode else 'nlg_eval_mode'
    setattr(inferencer, mode_attribute, nlg_eval_mode)

    evaluator = Evaluator(eval_tasks=tasks, align_func=inferencer.nlg_eval, save_all_tables=SAVE_ALL_TABLES)
    run_name = f'AlignScore-{nlg_eval_mode}-{base_model}' + ('_' + comment if comment else '')
    evaluator.result_save_name = f"align_eval/{run_name}"

    stopwatch = Timer()
    evaluator.evaluate()
    stopwatch.finish(run_name)
def eval_gptscore(api_key, gpt_model='davinci003', tasks=ALL_TASKS):
    """Benchmark the GPTScore baseline.

    Args:
        api_key: API key forwarded to ``GPTScoreScorer``.
        gpt_model: model name forwarded to ``GPTScoreScorer``; also part of the run name.
        tasks: evaluation task names forwarded to ``Evaluator``.
    """
    gptscore = GPTScoreScorer(api_key=api_key, gpt_model=gpt_model)
    # Same keyword and module flag as every other Evaluator call in this file
    # (`IS_SAVE_ALL_TABLES` is not defined anywhere in this module).
    evaluator = Evaluator(eval_tasks=tasks, align_func=gptscore.scorer, save_all_tables=SAVE_ALL_TABLES)
    # NOTE(review): unlike the other baselines this goes under nlg_eval_fact/baselines/,
    # which run_benchmarks does not create — confirm Evaluator creates it.
    evaluator.result_save_name = f"nlg_eval_fact/baselines/GPTScore-{gpt_model}"
    timer = Timer()
    evaluator.evaluate()
    timer.finish(f"GPTScore-{gpt_model}")
def eval_chatgptluo2023(api_key, chat_model='gpt-3.5-turbo', tasks=None):
    """Benchmark the ChatGPT prompting baseline implemented by ``ChatGPTLuo2023Scorer``.

    Args:
        api_key: API key forwarded to the scorer.
        chat_model: chat model name forwarded to the scorer; also part of the run name.
        tasks: evaluation task names; defaults to ``['qags_cnndm']``. The list is
            given both to the scorer (as ``task``) and to ``Evaluator``.
    """
    # Fresh list per call instead of a shared mutable default argument.
    if tasks is None:
        tasks = ['qags_cnndm']
    chatgpt = ChatGPTLuo2023Scorer(task=tasks, api_key=api_key, chat_model=chat_model)
    # Same keyword and module flag as every other Evaluator call in this file.
    evaluator = Evaluator(eval_tasks=tasks, align_func=chatgpt.scorer, save_all_tables=SAVE_ALL_TABLES)
    # NOTE(review): results go under nlg_eval_fact/baselines/, which run_benchmarks
    # does not create — confirm Evaluator creates it.
    evaluator.result_save_name = f"nlg_eval_fact/baselines/ChatGPTLuo2023-{chat_model}"
    timer = Timer()
    evaluator.evaluate()
    timer.finish(f"ChatGPTLuo2023-{chat_model}")
def eval_chatgptgao2023(api_key, chat_model='gpt-3.5-turbo', tasks=None):
    """Benchmark the ChatGPT prompting baseline implemented by ``ChatGPTGao2023Scorer``.

    Args:
        api_key: API key forwarded to the scorer.
        chat_model: chat model name forwarded to the scorer; also part of the run name.
        tasks: evaluation task names; defaults to ``['qags_cnndm']``. The list is
            given both to the scorer (as ``task``) and to ``Evaluator``.
    """
    # Fresh list per call instead of a shared mutable default argument.
    if tasks is None:
        tasks = ['qags_cnndm']
    chatgpt = ChatGPTGao2023Scorer(task=tasks, api_key=api_key, chat_model=chat_model)
    # Same keyword and module flag as every other Evaluator call in this file.
    evaluator = Evaluator(eval_tasks=tasks, align_func=chatgpt.scorer, save_all_tables=SAVE_ALL_TABLES)
    # NOTE(review): results go under nlg_eval_fact/baselines/, which run_benchmarks
    # does not create — confirm Evaluator creates it.
    evaluator.result_save_name = f"nlg_eval_fact/baselines/ChatGPTGao2023-{chat_model}"
    timer = Timer()
    evaluator.evaluate()
    timer.finish(f"ChatGPTGao2023-{chat_model}")
def eval_chatgptyichen2023(api_key, chat_model='gpt-3.5-turbo', tasks=None):
    """Benchmark the ChatGPT prompting baseline implemented by ``ChatGPTYiChen2023Scorer``.

    Args:
        api_key: API key forwarded to the scorer.
        chat_model: chat model name forwarded to the scorer; also part of the run name.
        tasks: evaluation task names; defaults to ``['qags_cnndm']``. The list is
            given both to the scorer (as ``task``) and to ``Evaluator``.
    """
    # Fresh list per call instead of a shared mutable default argument.
    if tasks is None:
        tasks = ['qags_cnndm']
    chatgpt = ChatGPTYiChen2023Scorer(task=tasks, api_key=api_key, chat_model=chat_model)
    # Same keyword and module flag as every other Evaluator call in this file.
    evaluator = Evaluator(eval_tasks=tasks, align_func=chatgpt.scorer, save_all_tables=SAVE_ALL_TABLES)
    # NOTE(review): results go under nlg_eval_fact/baselines/, which run_benchmarks
    # does not create — confirm Evaluator creates it.
    evaluator.result_save_name = f"nlg_eval_fact/baselines/ChatGPTYiChen2023-{chat_model}"
    timer = Timer()
    evaluator.evaluate()
    timer.finish(f"ChatGPTYiChen2023-{chat_model}")
def eval_chatgptshiqichen2023(api_key, chat_model='gpt-3.5-turbo', tasks=None):
    """Benchmark the ChatGPT prompting baseline implemented by ``ChatGPTShiqiChen2023Scorer``.

    Args:
        api_key: API key forwarded to the scorer.
        chat_model: chat model name forwarded to the scorer; also part of the run name.
        tasks: evaluation task names; defaults to ``['qags_cnndm']``. The list is
            given both to the scorer (as ``task``) and to ``Evaluator``.
    """
    # Fresh list per call instead of a shared mutable default argument.
    if tasks is None:
        tasks = ['qags_cnndm']
    chatgpt = ChatGPTShiqiChen2023Scorer(task=tasks, api_key=api_key, chat_model=chat_model)
    # Same keyword and module flag as every other Evaluator call in this file.
    evaluator = Evaluator(eval_tasks=tasks, align_func=chatgpt.scorer, save_all_tables=SAVE_ALL_TABLES)
    # NOTE(review): results go under nlg_eval_fact/baselines/, which run_benchmarks
    # does not create — confirm Evaluator creates it.
    evaluator.result_save_name = f"nlg_eval_fact/baselines/ChatGPTShiqiChen2023-{chat_model}"
    timer = Timer()
    evaluator.evaluate()
    timer.finish(f"ChatGPTShiqiChen2023-{chat_model}")
def run_benchmarks(args, argugment_error):
    """Run every metric whose on/off flag is set in ``args``, in a fixed order.

    Creates ``exp_results/baselines`` and ``exp_results/align_eval`` first.

    Args:
        args: namespace from the ``__main__`` argument parser. Each metric has a
            boolean switch (``args.ctc``, ``args.rouge``, ...) plus its own options.
        argugment_error: callable taking an error message. It is called when a
            selected metric is missing a required option. The CLI's callback forwards
            to ``parser.error``, which exits. If a caller's callback returns instead,
            the metric still runs with the incomplete options.
    """
    os.makedirs('exp_results/baselines', exist_ok=True)
    os.makedirs('exp_results/align_eval', exist_ok=True)

    if args.alignscore:
        if not all((args.alignscore_model, args.alignscore_ckpt, args.alignscore_eval_mode)):
            argugment_error(
                '--alignscore-model, --alignscore-ckpt, and --alignscore-eval-mode '
                'must be specified to run AlignScore'
            )
        eval_align_nlg(
            nlg_eval_mode=args.alignscore_eval_mode,
            ckpt_path=args.alignscore_ckpt,
            base_model=args.alignscore_model,
            device=args.device,
            tasks=args.tasks,
            comment=args.alignscore_comment,
        )

    if args.ctc:
        if not args.ctc_type:
            argugment_error('--ctc-type must be specified to run CTC baseline')
        for ctc_type in args.ctc_type:
            eval_ctc(ctc_type, tasks=args.tasks)

    if args.simcse:
        if not args.simcse_ckpt:
            argugment_error('--simcse-ckpt must be specified to run SimCSE baseline')
        for ckpt in args.simcse_ckpt:
            eval_simcse(ckpt, device=args.device, tasks=args.tasks)

    if args.bleurt:
        if not args.bleurt_ckpt:
            argugment_error('--bleurt-ckpt must be specified to run BLEURT baseline')
        eval_bleurt(args.bleurt_ckpt, tasks=args.tasks)

    if args.bertscore:
        if not args.bertscore_ckpt or not args.bertscore_batch_size:
            argugment_error('--bertscore-ckpt and --bertscore-batch-size must be specified to run BERTScore baseline')
        for ckpt in args.bertscore_ckpt:
            eval_bertscore(ckpt, device=args.device, tasks=args.tasks, batch_size=args.bertscore_batch_size)

    if args.bartscore:
        if not args.bartscore_ckpt:
            argugment_error('--bartscore-ckpt must be specified to run BARTScore baseline')
        for ckpt in args.bartscore_ckpt:
            eval_bartscore(ckpt, device=args.device, tasks=args.tasks)

    if args.mnli:
        if not args.mnli_ckpt:
            argugment_error('--mnli-ckpt must be specified to run MNLI baseline')
        for ckpt in args.mnli_ckpt:
            eval_mnli(model=ckpt, device=args.device, tasks=args.tasks)

    if args.ner:
        eval_ner(tasks=args.tasks)

    if args.unieval:
        eval_unieval(tasks=args.tasks, device=args.device)

    if args.feqa:
        eval_feqa(tasks=args.tasks)

    if args.questeval:
        eval_questeval(tasks=args.tasks)

    if args.qafacteval:
        # Forward --device like the other device-aware baselines instead of
        # silently using eval_qafacteval's own default.
        eval_qafacteval(tasks=args.tasks, device=args.device)

    if args.bleu:
        if not args.bleu_ngram:
            argugment_error('--bleu-ngram must be specified to run BLEU baseline')
        for n_grams in args.bleu_ngram:
            eval_bleu(tasks=args.tasks, n_grams=n_grams)

    if args.rouge:
        if not args.rouge_type:
            argugment_error('--rouge-type must be specified to run ROUGE baseline')
        for rouge_type in args.rouge_type:
            eval_rouge(tasks=args.tasks, rouge_type=rouge_type)

    if args.dae:
        if not args.dae_ckpt:
            argugment_error('--dae-ckpt must be specified to run DAE baseline')
        eval_dae(tasks=args.tasks, model_dir=os.path.abspath(args.dae_ckpt))

    if args.factcc:
        if not all((args.factcc_script, args.factcc_test_data, args.factcc_result_path)):
            argugment_error('--factcc-script, --factcc-test-data, and --factcc-result-path must be specified to run FactCC baseline')
        eval_factcc(
            tasks=args.tasks,
            script_path=os.path.abspath(args.factcc_script),
            test_data_path=os.path.abspath(args.factcc_test_data),
            result_path=os.path.abspath(args.factcc_result_path),
        )

    if args.blanc:
        if not args.blanc_batch_size:
            argugment_error('--blanc-batch-size must be specified to run BLANC baseline')
        eval_blanc(tasks=args.tasks, device=args.device, batch_size=args.blanc_batch_size)

    if args.summac:
        if not args.summac_type:
            argugment_error('--summac-type must be specified to run SummaC baseline')
        for summac_type in args.summac_type:
            eval_summac(tasks=args.tasks, device=args.device, summac_type=summac_type)
if __name__ == "__main__":
    # Task subset offered on the command line (and used by default).
    FACT_EVAL_TASKS = ['summac', 'true', 'xsumfaith', 'summeval', 'qags_xsum', 'qags_cnndm', 'newsroom', 'rank19', 'frank', 'samsum']

    def _metric_group(owner, title, flag, help_text, description=None):
        """Add an argument group to ``owner`` whose first option is the boolean switch ``flag``."""
        group = owner.add_argument_group(title, description=description)
        group.add_argument(flag, action='store_true', help=help_text)
        return group

    cli = ArgumentParser()
    cli.add_argument('--tasks', nargs='+', type=str, default=FACT_EVAL_TASKS, choices=FACT_EVAL_TASKS)
    cli.add_argument('--device', type=str, default='cuda:0')
    cli.add_argument('--timer', action='store_true', help='Time all metric runs')

    group = _metric_group(cli, 'AlignScore', '--alignscore', 'Run AlignScore benchmark')
    group.add_argument('--alignscore-model', type=str, choices=['roberta-base', 'roberta-large'])
    group.add_argument('--alignscore-ckpt', type=str)
    group.add_argument(
        '--alignscore-eval-mode',
        type=str,
        choices=['bin', 'bin_sp', 'nli', 'nli_sp', 'reg', 'reg_sp', 'smart-n', 'smart-l'],
        default='nli_sp',
    )
    group.add_argument('--alignscore-comment', type=str, default='')

    group = _metric_group(cli, 'Baseline - CTC', '--ctc', 'Run CTC baseline')
    group.add_argument('--ctc-type', nargs='*', type=str, choices=['D-cnndm', 'E-roberta', 'R-cnndm'], default=['D-cnndm'])

    simcse_checkpoints = [
        'princeton-nlp/unsup-simcse-bert-base-uncased',
        'princeton-nlp/unsup-simcse-bert-large-uncased',
        'princeton-nlp/unsup-simcse-roberta-base',
        'princeton-nlp/unsup-simcse-roberta-large',
        'princeton-nlp/sup-simcse-bert-base-uncased',
        'princeton-nlp/sup-simcse-bert-large-uncased',
        'princeton-nlp/sup-simcse-roberta-base',
        'princeton-nlp/sup-simcse-roberta-large',
    ]
    group = _metric_group(cli, 'Baseline - SimCSE', '--simcse', 'Run SimCSE baseline')
    group.add_argument('--simcse-ckpt', nargs='*', type=str, choices=simcse_checkpoints, default=['princeton-nlp/sup-simcse-roberta-large'])

    group = _metric_group(cli, 'Baseline - BLEURT', '--bleurt', 'Run BLEURT baseline')
    group.add_argument('--bleurt-ckpt', type=str)

    group = _metric_group(cli, 'Baseline - BERTScore', '--bertscore', 'Run BERTScore baseline')
    group.add_argument('--bertscore-ckpt', nargs='*', type=str, default=['microsoft/deberta-xlarge-mnli'])
    group.add_argument('--bertscore-batch-size', type=int, default=16)

    group = _metric_group(
        cli, 'Baseline - BARTScore', '--bartscore', 'Run BARTScore baseline',
        description='Please clone https://github.com/neulab/BARTScore to baselines/BARTScore.',
    )
    group.add_argument('--bartscore-ckpt', type=str, nargs='*', default=['facebook/bart-large-cnn'])

    group = _metric_group(cli, 'Baseline - MNLI', '--mnli', 'Run MNLI baseline')
    group.add_argument('--mnli-ckpt', nargs='*', type=str, default=['roberta-large-mnli'])

    # Baselines that only take an on/off switch; their code has to be cloned separately.
    for title, flag, help_text, clone_hint in (
        ('Baseline - NER overlap', '--ner', 'Run NER overlap baseline',
         'Please clone https://github.com/tingofurro/summac to baselines/summac.'),
        ('Baseline - UniEval', '--unieval', 'Run UniEval baseline',
         'Please clone https://github.com/maszhongming/UniEval to baselines/UniEval.'),
        ('Baseline - FEQA', '--feqa', 'Run FEQA baseline',
         'Please clone https://github.com/esdurmus/feqa to baselines/feqa'),
        ('Baseline - QuestEval', '--questeval', 'Run QuestEval baseline',
         'Please clone https://github.com/ThomasScialom/QuestEval to baselines/QuestEval.'),
        ('Baseline - QAFactEval', '--qafacteval', 'Run QAFactEval baseline',
         'Please clone https://github.com/salesforce/QAFactEval to baselines/QAFactEval.'),
    ):
        _metric_group(cli, title, flag, help_text, description=clone_hint)

    group = _metric_group(cli, 'Baseline - BLEU', '--bleu', 'Run BLEU baseline')
    group.add_argument('--bleu-ngram', nargs='*', type=int, choices=[1, 2, 3, 4], default=[1, 2, 3, 4])

    group = _metric_group(cli, 'Baseline - ROUGE', '--rouge', 'Run ROUGE baseline')
    group.add_argument('--rouge-type', nargs='*', type=str, choices=['1', '2', 'l'], default=['1', '2', 'l'])

    group = _metric_group(cli, 'Baseline - DAE', '--dae', 'Run DAE baseline')
    group.add_argument('--dae-ckpt', type=str)

    group = _metric_group(cli, 'Baseline - FactCC', '--factcc', 'Run FactCC baseline')
    for factcc_option in ('--factcc-script', '--factcc-test-data', '--factcc-result-path'):
        group.add_argument(factcc_option, type=str)

    group = _metric_group(cli, 'Baseline - BLANC', '--blanc', 'Run BLANC baseline')
    group.add_argument('--blanc-batch-size', type=int, default=64)

    group = _metric_group(
        cli, 'Baseline - SummaC', '--summac', 'Run SummaC baseline',
        description='Please clone https://github.com/tingofurro/summac to baselines/summac.',
    )
    group.add_argument('--summac-type', nargs='*', type=str, choices=['conv', 'zs'], default=['conv', 'zs'])

    cli_args = cli.parse_args()
    if cli_args.timer:
        # Module-level assignment: Timer.finish reads this flag at call time.
        SAVE_AND_PRINT_TIMER = True
    # parser.error prints usage plus the message and exits, as run_benchmarks expects.
    run_benchmarks(cli_args, cli.error)