# CelebChat / run_eval.py
# Evaluates CelebBot answers against reference answers (BLEU, METEOR, ROUGE-L, BERTScore).
# Source snapshot: commit fe183af ("new commits") by lhzstar.
import itertools
import re
import spacy
import json
import evaluate
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel
from unlimiformer import Unlimiformer, UnlimiformerArguments
import torch
from utils import *
from celebbot import CelebBot
# Hugging Face model IDs for the answer generator and for the
# sentence-transformer handed to CelebBot.
QA_MODEL_ID = "google/flan-t5-xl"
SENTTR_MODEL_ID = "sentence-transformers/all-mpnet-base-v2"
# Celebrities to evaluate; predictions are generated in this list's order.
celeb_names = ["Cate Blanchett", "David Beckham", "Emma Watson", "Lady Gaga", "Madonna", "Mark Zuckerberg"]
# Whether to convert the QA model with Unlimiformer.convert_model before inference.
USE_UNLIMIFORMER = True
# Forwarded to CelebBot as its `top_k` argument.
TOP_K = 16

celeb_data = get_celeb_data("data.json")
# Gold answers, flattened in celeb_names order. This is the same order in which
# the prediction loop visits celebrities, so references[i] lines up with
# predictions[i]. Iterating celeb_data.items() instead follows data.json's key
# order, which silently misaligns the pairs whenever the two orders differ.
# A name missing from data.json now fails here with KeyError rather than
# later inside the generation loop.
# NOTE(review): assumes each celebrity's "answers" list is parallel to its
# "questions" list -- confirm against data.json.
references = list(
    itertools.chain.from_iterable(celeb_data[name]["answers"] for name in celeb_names)
)
predictions = []  # filled by the generation loop, one entry per question
device = 'cpu'
# Load the seq2seq answer generator and its tokenizer.
QA_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_ID)
QA_model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL_ID)

if USE_UNLIMIFORMER:
    # Forward Unlimiformer's own default arguments. Several keyword names
    # differ from the UnlimiformerArguments attribute they are read from
    # (e.g. exclude_attention <- unlimiformer_exclude).
    defaults = UnlimiformerArguments()
    unlimiformer_kwargs = dict(
        layer_begin=defaults.layer_begin,
        layer_end=defaults.layer_end,
        unlimiformer_head_num=defaults.unlimiformer_head_num,
        exclude_attention=defaults.unlimiformer_exclude,
        chunk_overlap=defaults.unlimiformer_chunk_overlap,
        model_encoder_max_len=defaults.unlimiformer_chunk_size,
        verbose=defaults.unlimiformer_verbose,
        tokenizer=QA_tokenizer,
        unlimiformer_training=defaults.unlimiformer_training,
        use_datastore=defaults.use_datastore,
        flat_index=defaults.flat_index,
        test_datastore=defaults.test_datastore,
        reconstruct_embeddings=defaults.reconstruct_embeddings,
        gpu_datastore=defaults.gpu_datastore,
        gpu_index=defaults.gpu_index,
    )
    QA_model = Unlimiformer.convert_model(QA_model, **unlimiformer_kwargs)

# Whether converted or not, the QA model ends up on `device`.
QA_model = QA_model.to(device)

# Sentence-transformer encoder and tokenizer, passed through to CelebBot.
sentTr_tokenizer = AutoTokenizer.from_pretrained(SENTTR_MODEL_ID)
sentTr_model = AutoModel.from_pretrained(SENTTR_MODEL_ID).to(device)
# Britannica biography slugs that cannot be derived by hyphenating the name.
BRITANNICA_SLUG_OVERRIDES = {
    "Madonna": "Madonna-American-singer-and-actress",
    "Anne Hathaway": "Anne-Hathaway-American-actress",
}

# Load the spaCy pipeline once. It does not depend on the celebrity, and
# loading en_core_web_lg on every iteration is needlessly slow.
# NOTE(review): one pipeline object is now shared by every CelebBot; this
# assumes CelebBot does not mutate it -- confirm.
spacy_model = spacy.load("en_core_web_lg")

for celeb_name in celeb_names:
    gender = celeb_data[celeb_name]["gender"]
    # Default slug: spaces replaced by hyphens ("Emma Watson" -> "Emma-Watson").
    name = BRITANNICA_SLUG_OVERRIDES.get(celeb_name, celeb_name.replace(" ", "-"))
    knowledge = get_article(f"https://www.britannica.com/biography/{name}")
    # Split the fetched article into whitespace-stripped sentences for CelebBot.
    knowledge_sents = [i.text.strip() for i in spacy_model(knowledge).sents]
    ai = CelebBot(celeb_name, gender, QA_tokenizer, QA_model, sentTr_tokenizer, sentTr_model, spacy_model, knowledge_sents, top_k=TOP_K)
    # Ask every question for this celebrity and collect the answers in order.
    for q in celeb_data[celeb_name]["questions"]:
        ai.text = q
        response = ai.question_answer()
        print("response:", response)
        predictions.append(response)
# Persist the raw predictions, one per line, for later inspection.
# The context manager closes the file even if a write fails. UTF-8 is explicit
# because generated text may contain characters that the platform-default
# encoding (e.g. cp1252 on Windows) cannot encode.
with open('predictions.txt', 'w', encoding='utf-8') as file:
    file.writelines(prediction + "\n" for prediction in predictions)
# Score the predictions against the gold answers. Every metric receives the
# same predictions/references pair; only metric-specific options differ.
metric_inputs = {"predictions": predictions, "references": references}

bleu = evaluate.load("bleu")
results = bleu.compute(**metric_inputs, max_order=4)
print(f"BLEU: {round(results['bleu'], 2)}")

meteor = evaluate.load("meteor")
results = meteor.compute(**metric_inputs)
print(f"METEOR: {round(results['meteor'], 2)}")

# Report the ROUGE-L score.
rouge = evaluate.load("rouge")
results = rouge.compute(**metric_inputs)
print(f"ROUGE: {round(results['rougeL'], 2)}")

bertscore = evaluate.load("bertscore")
results = bertscore.compute(**metric_inputs, rescale_with_baseline=True, lang="en")
# results['f1'] is a list of scores (presumably one per prediction); report its mean.
f1_scores = results['f1']
mean_f1 = sum(f1_scores) / len(f1_scores)
print(f"F1: {round(mean_f1, 2)}")