clickbaitonator / fudge /eval_poetry_metrics.py
dsvilarko
Initial commit
c4ebaf8
raw
history blame
6.41 kB
from argparse import ArgumentParser
import math
import string
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForSequenceClassification
from poetry_util import is_iambic, perfect_rhyme_end, count_syllables
from constants import *
def conditional_perplexity(prefix, pred, tokenizer, model, device='cuda', sep_losses=False):
    """Return the perplexity of ``pred`` conditioned on ``prefix``.

    The negative log-likelihood of the prefix alone is subtracted from that of
    ``prefix + pred``; the remainder is averaged over the extra tokens that
    ``pred`` contributes and then exponentiated.

    Args:
        prefix: Conditioning text.
        pred: Continuation whose perplexity is measured.
        tokenizer: Object providing ``decode`` and
            ``encode(text, return_tensors='pt')``.
        model: Language model called as ``model(input_ids, labels=input_ids)``;
            element 0 of its output is used as the loss.
        device: Device the encoded token tensors are moved to.
        sep_losses: If True, element 0 of the model output is summed (i.e. it
            is treated as separate per-token losses). If False, it is treated
            as a mean loss and rescaled by the number of predicted positions
            (sequence length minus one).

    Returns:
        The perplexity as a Python float.

    Note:
        Assumes ``pred`` adds at least one token to the tokenized sequence;
        otherwise the per-token average divides by a zero token count.
    """
    sentence = prefix + pred
    # The decoded form of token id 0 is prepended as a start-of-sequence marker.
    sos_token = tokenizer.decode([0])
    prefix_ids = tokenizer.encode(
        sos_token + prefix.replace(EOT_TOKEN, ' ').strip(), return_tensors='pt'
    ).to(device)
    full_ids = tokenizer.encode(
        sos_token + sentence.replace(EOT_TOKEN, ' ').strip(), return_tensors='pt'
    ).to(device)

    # Pure evaluation: disable autograd so no gradient graph is built for the
    # two forward passes (same convention as grammaticality()).
    with torch.no_grad():
        prefix_out = model(prefix_ids, labels=prefix_ids)[0]
        full_out = model(full_ids, labels=full_ids)[0]

    if sep_losses:
        prefix_loss = prefix_out.sum()
        full_loss = full_out.sum()
    else:
        prefix_loss = prefix_out * (prefix_ids.shape[1] - 1)  # neg log prob of prefix
        full_loss = full_out * (full_ids.shape[1] - 1)  # neg log prob of full seq

    pred_loss = full_loss - prefix_loss  # neg log prob of preds given prefix
    avg_pred_loss = pred_loss / (full_ids.shape[1] - prefix_ids.shape[1])
    return math.exp(avg_pred_loss.item())
def grammaticality(sentences, tokenizer, model, device='cuda'):
    """Average the classifier's "grammatical" probability over ``sentences``.

    Each sentence is encoded on its own, passed through ``model``, and the
    flattened element 0 of the output is soft-maxed; the value at index 1 is
    taken as the probability that the sentence is grammatical.

    Args:
        sentences: Sized collection of strings to score.
        tokenizer: Object providing ``encode(text, return_tensors='pt')``.
        model: Classifier whose output element 0 holds the class scores.
        device: Device the encoded inputs are moved to.

    Returns:
        The sum of the per-sentence probabilities divided by
        ``len(sentences)``.
    """
    prob_sum = 0
    with torch.no_grad():
        for text in tqdm(sentences, total=len(sentences)):
            token_ids = tokenizer.encode(text, return_tensors='pt').to(device)
            scores = model(token_ids)[0].flatten()
            class_probs = F.softmax(scores, dim=0)
            prob_sum += class_probs[1]
    # avg probability of grammaticality according to model
    return prob_sum / len(sentences)
def distinctness(sentences):
    """Compute distinct-1/2/3 scores for a collection of sentences.

    Each score is the number of unique unigrams, bigrams or trigrams found
    across all sentences, divided by the total number of words. Sentences are
    split on single spaces only, so consecutive or leading spaces produce
    empty-string "words" that are counted like any other token.

    Args:
        sentences: Iterable of strings.

    Returns:
        A 3-tuple ``(distinct_1, distinct_2, distinct_3)`` of floats. When
        there are no sentences at all, ``(0.0, 0.0, 0.0)`` is returned.
    """
    unigrams = set()
    bigrams = set()
    trigrams = set()
    total_words = 0
    for sentence in sentences:
        words = sentence.split(' ')
        total_words += len(words)
        unigrams.update(words)
        # Store n-grams as tuples rather than '_'-joined strings so that
        # tokens which themselves contain underscores cannot collide.
        bigrams.update(zip(words, words[1:]))
        trigrams.update(zip(words, words[1:], words[2:]))
    if total_words == 0:
        # Only reachable with no sentences; avoid dividing by zero.
        return 0.0, 0.0, 0.0
    return (
        len(unigrams) / total_words,
        len(bigrams) / total_words,
        len(trigrams) / total_words,
    )
if __name__=='__main__':
    # Script entry point: reads one prediction and one prefix per line from the
    # two input files, then prints the poetry constraint statistics,
    # distinctness, grammaticality and language-model perplexities.
    parser = ArgumentParser()
    parser.add_argument('--pred_file', type=str)
    parser.add_argument('--prefix_file', type=str)
    parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
    args = parser.parse_args()

    with open(args.pred_file, 'r') as rf:
        # Drop only the line terminator, not beginning spaces if any. Using
        # rstrip('\n') instead of line[:-1] keeps the final character of a
        # last line that has no trailing newline.
        preds = [line.rstrip('\n') for line in rf]
    with open(args.prefix_file, 'r') as rf:
        prefixes = [line.strip() for line in rf]

    # Explicit checks (not assert, which is stripped under -O): zip() below
    # would otherwise silently truncate to the shorter file.
    if len(prefixes) != len(preds):
        raise ValueError(
            'prefix file has %d lines but pred file has %d lines'
            % (len(prefixes), len(preds))
        )
    total = len(prefixes)
    if total == 0:
        raise ValueError('no examples to evaluate: input files are empty')

    rhymes = 0
    iambic = 0
    ten_syllables = 0
    end = 0
    diff_rhymes = 0
    all_success = 0
    for prefix, pred in zip(prefixes, preds):
        # Evaluate each constraint once and reuse the result for both the
        # individual counters and the combined "all success" counter.
        iambic_ok = is_iambic(pred)
        rhyme_ok = perfect_rhyme_end(prefix, pred)
        syllables_ok = count_syllables(pred) == 10
        stripped_pred = pred.strip()
        # A blank prediction has no final character; count it as not ended
        # instead of raising IndexError.
        end_ok = bool(stripped_pred) and stripped_pred[-1] in PHRASE_ENDS

        if iambic_ok:
            iambic += 1
        if rhyme_ok:
            rhymes += 1
            # Of the rhyming pairs, count those whose final word differs from
            # the prefix's final word (ignoring surrounding punctuation).
            if prefix.split()[-1].strip(string.punctuation) != pred.split()[-1].strip(string.punctuation):
                diff_rhymes += 1
        if syllables_ok:
            ten_syllables += 1
        if end_ok:
            end += 1
        if iambic_ok and rhyme_ok and syllables_ok and end_ok:
            all_success += 1

    print('iambic', iambic, 'out of', total, ', frac', iambic / total)
    print('rhymes', rhymes, 'out of', total, ', frac', rhymes / total)
    print('end sentence', end, 'out of', total, ', frac', end / total)
    print('10 syllables', ten_syllables, 'out of', total, ', frac', ten_syllables / total)
    print('all success', all_success, 'out of', total, ', frac', all_success / total)
    print('rhymes with diff word', diff_rhymes, 'out of', total, ', frac', diff_rhymes / total)

    print('distinctness', distinctness(preds))

    grammar_tokenizer = AutoTokenizer.from_pretrained('textattack/roberta-base-CoLA')
    grammar_model = AutoModelForSequenceClassification.from_pretrained('textattack/roberta-base-CoLA').to(args.device)
    grammar_model.eval()
    print('grammaticality', grammaticality(preds, grammar_tokenizer, grammar_model, device=args.device))

    # Transformer-XL perplexity; sep_losses=True makes conditional_perplexity
    # sum the model's loss output instead of rescaling a mean.
    eval_tokenizer = AutoTokenizer.from_pretrained('transfo-xl-wt103')
    eval_model = AutoModelWithLMHead.from_pretrained('transfo-xl-wt103').to(args.device)
    eval_model.eval()
    perplexities = [
        conditional_perplexity(prefix, pred, eval_tokenizer, eval_model, device=args.device, sep_losses=True)
        for prefix, pred in zip(prefixes, preds)
    ]
    print('transformer xl perplexity', np.mean(perplexities), '+/-', np.std(perplexities))

    # GPT perplexity.
    eval_tokenizer = AutoTokenizer.from_pretrained('openai-gpt')
    eval_model = AutoModelWithLMHead.from_pretrained('openai-gpt').to(args.device)
    eval_model.eval()
    perplexities = [
        conditional_perplexity(prefix, pred, eval_tokenizer, eval_model, device=args.device)
        for prefix, pred in zip(prefixes, preds)
    ]
    print('gpt perplexity', np.mean(perplexities), '+/-', np.std(perplexities))

    # NOTE: uncomment this section with the path to the Shakespeare-finetuned GPT to evaluate this metric. it's in ckpt/poetry/gpt_finetune_shakespeare.pth.tar.
    # eval_tokenizer = AutoTokenizer.from_pretrained('openai-gpt')
    # eval_model = AutoModelWithLMHead.from_pretrained('openai-gpt').to(args.device)
    # checkpoint = torch.load('***PATH_TO_SHAKESPEARE_FINETUNED_GPT***', map_location=args.device)
    # mod_dict = {}
    # for key in checkpoint['state_dict']:
    #     mod_dict[key.replace('classifier.', '')] = checkpoint['state_dict'][key]
    # eval_model.load_state_dict(mod_dict)
    # eval_model.eval()
    # perplexities = []
    # for prefix, pred in zip(prefixes, preds):
    #     perplexities.append(conditional_perplexity(prefix, pred, eval_tokenizer, eval_model, device=args.device))
    # print('shakespeare finetuned perplexity', np.mean(perplexities), '+/-', np.std(perplexities))