Dusan Svilarkovic
Adding Fudge
fc5ecba
raw
history blame
5.72 kB
import os
import random
import time
import pickle
import math
from argparse import ArgumentParser
from collections import defaultdict
import string
import csv
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForSequenceClassification
from data import Dataset
from model import Model
from util import save_checkpoint, ProgressMeter, AverageMeter, num_params, pad_mask
from predict import predict
from constants import *
def tw_topic_eval(sentences, category, tw_dir, cap=None):
    """Count test-wordlist hits over ``sentences`` for one category.

    The wordlist is read from ``<tw_dir>/<category>.txt``, one entry per line;
    each entry is stripped and lowercased. Every sentence is lowercased, split
    on whitespace, and each token has leading/trailing punctuation removed. A
    sentence scores one point for every distinct wordlist entry that equals
    one of its tokens.

    Args:
        sentences: iterable of generated strings to score.
        category: wordlist file stem (``.txt`` is appended).
        tw_dir: directory containing the wordlist files.
        cap: optional upper bound on the points a single sentence may
            contribute; ``None`` means uncapped.

    Returns:
        int: total number of matches summed over all sentences.

    Raises:
        OSError: if the wordlist file cannot be opened.
    """
    with open(os.path.join(tw_dir, category + '.txt'), 'r') as rf:
        # A set keeps the count to *distinct* words even if the file repeats
        # an entry.
        words = {line.strip().lower() for line in rf}
    # A blank line would otherwise become the entry '' and match every token
    # that consists solely of punctuation (e.g. a standalone "-").
    words.discard('')

    num_match = 0
    for sent in sentences:
        tokens = {
            tok.strip(string.punctuation)
            for tok in sent.strip().lower().split()
        }
        sent_match = len(words & tokens)
        num_match += sent_match if cap is None else min(cap, sent_match)
    return num_match
def perplexity(sentences, tokenizer, model, device='cuda'):
    """Score each sentence with ``model`` and summarise the perplexities.

    For every sentence, EOT_TOKEN is replaced by a space, the text is
    stripped, prefixed with the decoding of token id 0 and encoded. The model
    is called with the encoded ids as both input and labels; the mean of its
    first output (presumably the LM loss -- confirm against the model class)
    is exponentiated to give that sentence's perplexity.

    Args:
        sentences: sequence of strings to score.
        tokenizer: object providing ``decode`` and ``encode``.
        model: callable taking ``(input_ids, labels=...)``.
        device: device the encoded inputs are moved to.

    Returns:
        tuple: ``(np.mean, np.std)`` of the per-sentence perplexities.
    """
    scores = []
    with torch.no_grad():
        # Text form of token id 0, prepended to every sentence.
        prefix = tokenizer.decode([0])
        for text in tqdm(sentences, total=len(sentences)):
            cleaned = text.replace(EOT_TOKEN, ' ').strip()
            input_ids = tokenizer.encode(prefix + cleaned, return_tensors='pt')
            input_ids = input_ids.to(device)
            outputs = model(input_ids, labels=input_ids)
            loss = outputs[0].mean()
            scores.append(torch.exp(loss).flatten().cpu().item())
    return np.mean(scores), np.std(scores)
def grammaticality(sentences, tokenizer, model, device='cuda'):
    """Average the classifier's index-1 probability over ``sentences``.

    Each sentence is encoded, moved to ``device`` and passed to ``model``; the
    first model output is flattened and softmaxed, and the probability at
    index 1 (treated by the caller as the "grammatical" class) is accumulated.

    Args:
        sentences: sequence of strings to classify.
        tokenizer: object providing ``encode``.
        model: callable returning an indexable whose first item holds logits.
        device: device the encoded inputs are moved to.

    Returns:
        The accumulated probability divided by ``len(sentences)`` (a tensor,
        since the per-sentence probabilities are tensor elements).
    """
    with torch.no_grad():
        running_sum = 0
        for text in tqdm(sentences, total=len(sentences)):
            input_ids = tokenizer.encode(text, return_tensors='pt').to(device)
            logits = model(input_ids)[0].flatten()
            class_probs = F.softmax(logits, dim=0)
            running_sum += class_probs[1]
        # avg probability of grammaticality according to model
        return running_sum / len(sentences)
def distinctness(results):
    """Compute distinct-1/2/3 ratios per category and their macro average.

    For each category, every output has EOT_TOKEN replaced by a space, is
    stripped and split on single spaces. The distinct-n ratio is the number of
    unique n-grams seen across the category's outputs divided by the
    category's total token count.

    Args:
        results: mapping from category to a list of generated strings.

    Returns:
        tuple: ``(per_category, (avg_d1, avg_d2, avg_d3))`` where
        ``per_category`` is a list of
        ``(category, 'DISTINCTNESS', d1, d2, d3)`` tuples and the averages are
        unweighted means over the listed categories. Categories with no
        outputs are omitted. If no category has any output the result is
        ``([], (0.0, 0.0, 0.0))`` instead of a ZeroDivisionError.
    """
    d1, d2, d3 = defaultdict(set), defaultdict(set), defaultdict(set)
    total_words = defaultdict(int)
    for cw, outputs in results.items():
        for o in outputs:
            toks = o.replace(EOT_TOKEN, ' ').strip().split(' ')
            total_words[cw] += len(toks)
            d1[cw].update(toks)
            d2[cw].update(
                toks[i] + ' ' + toks[i + 1] for i in range(len(toks) - 1))
            d3[cw].update(
                toks[i] + ' ' + toks[i + 1] + ' ' + toks[i + 2]
                for i in range(len(toks) - 2))

    # Nothing to average: avoid dividing by a zero category count below.
    if not total_words:
        return [], (0.0, 0.0, 0.0)

    return_info = []
    avg_d1, avg_d2, avg_d3 = 0, 0, 0
    for cw, n_words in total_words.items():
        r1 = len(d1[cw]) / n_words
        r2 = len(d2[cw]) / n_words
        r3 = len(d3[cw]) / n_words
        return_info.append((cw, 'DISTINCTNESS', r1, r2, r3))
        avg_d1 += r1
        avg_d2 += r2
        avg_d3 += r3

    n_categories = len(total_words)
    return return_info, (
        avg_d1 / n_categories, avg_d2 / n_categories, avg_d3 / n_categories)
if __name__ == '__main__':
    # Script entry point: load generations from a CSV log and print the
    # topic-match, distinctness, grammaticality and perplexity metrics.
    parser = ArgumentParser()
    parser.add_argument('--log_file', type=str, required=True, help='where to load results from')
    parser.add_argument('--tw_dir', type=str, default='test_wordlists', help='test wordlists')
    # NOTE: --batch_size is parsed but not used anywhere in this block.
    parser.add_argument('--batch_size', type=int, default=8, help='max samples at a time')
    parser.add_argument('--cap_per_example', type=int, default=None, help='max matches to count per sentence')
    parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
    args = parser.parse_args()

    tw_topic_match_c_total = 0
    category_totals_c = defaultdict(lambda:0)
    results = defaultdict(lambda: [])

    # The log must be a CSV with 'category' and 'generation' columns.
    # newline='' lets the csv module handle newlines embedded in quoted
    # fields itself, as its documentation requires.
    with open(args.log_file, 'r', newline='') as rf:
        data = list(csv.DictReader(rf))
        for line in data:
            results[line['category']].append(line['generation'])

    all_c_sents = []
    for category, condition_results in results.items():
        tw_topic_match_c = tw_topic_eval(condition_results, category, args.tw_dir, cap=args.cap_per_example)
        tw_topic_match_c_total += tw_topic_match_c
        category_totals_c[category] += tw_topic_match_c
        all_c_sents += condition_results

    print('Test wordlist matches (divide by num outputs to get the Success metric):', tw_topic_match_c_total)
    print('per category:', category_totals_c)

    dist_info_by_category, dist_overall = distinctness(results)
    print('Overall avg distinctness:', dist_overall)
    print('per category:', dist_info_by_category)

    grammar_tokenizer = AutoTokenizer.from_pretrained('textattack/roberta-base-CoLA')
    grammar_model = AutoModelForSequenceClassification.from_pretrained('textattack/roberta-base-CoLA').to(args.device)
    grammar_model.eval()
    print('grammaticality:', grammaticality(all_c_sents, grammar_tokenizer, grammar_model, device=args.device))

    # Pass the chosen device explicitly: perplexity() defaults to 'cuda', so
    # without it --device cpu would move the inputs to CUDA while the model
    # stays on the CPU.
    eval_tokenizer = AutoTokenizer.from_pretrained('openai-gpt')
    eval_model = AutoModelWithLMHead.from_pretrained('openai-gpt').to(args.device)
    eval_model.eval()
    print('GPT perplexity:', perplexity(all_c_sents, eval_tokenizer, eval_model, device=args.device))

    eval_tokenizer = AutoTokenizer.from_pretrained('transfo-xl-wt103')
    eval_model = AutoModelWithLMHead.from_pretrained('transfo-xl-wt103').to(args.device)
    eval_model.eval()
    print('TFXL perplexity:', perplexity(all_c_sents, eval_tokenizer, eval_model, device=args.device))