import os
import random
import time
import pickle
import math
from argparse import ArgumentParser
from collections import defaultdict
import string
import csv

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForSequenceClassification

from data import Dataset
from model import Model
from util import save_checkpoint, ProgressMeter, AverageMeter, num_params, pad_mask
from predict import predict
from constants import *
def tw_topic_eval(sentences, category, tw_dir, cap=None):
    # Count matches of distinct wordlist words: each distinct word from the
    # category's test wordlist found in a sentence counts once, optionally
    # capped at `cap` matches per sentence.
    words = []
    with open(os.path.join(tw_dir, category + '.txt'), 'r') as rf:
        for line in rf:
            words.append(line.strip().lower())
    num_match = 0
    for sent in sentences:
        sent_match = 0
        sent = sent.strip().lower().split()
        sent = [tok.strip(string.punctuation) for tok in sent]  # drop surrounding punctuation before matching
        for word in words:
            if word in sent:
                sent_match += 1
        if cap is None:
            num_match += sent_match
        else:
            num_match += min(cap, sent_match)
    return num_match
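
# Hypothetical usage sketch (not part of the original pipeline): assuming a
# wordlist file test_wordlists/space.txt with one keyword per line, this
# never-called helper illustrates the expected inputs and the effect of `cap`.
def _tw_topic_eval_example():
    sentences = ['The telescope observed a distant galaxy .',
                 'Nothing on-topic in this one .']
    # With cap=1 each sentence contributes at most one match, so the result is
    # the number of sentences containing any wordlist keyword at all.
    return tw_topic_eval(sentences, 'space', 'test_wordlists', cap=1)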
def perplexity(sentences, tokenizer, model, device='cuda'):
    # Calculate mean and std of an external LM's perplexity on the generations.
    with torch.no_grad():
        ppl = []
        sos_token = tokenizer.decode([0])  # token id 0 serves as a pseudo start-of-sequence token
        for sentence in tqdm(sentences, total=len(sentences)):
            full_tensor_input = tokenizer.encode(sos_token + sentence.replace(EOT_TOKEN, ' ').strip(), return_tensors='pt').to(device)
            full_loss = model(full_tensor_input, labels=full_tensor_input)[0].mean()
            ppl.append(torch.exp(full_loss).flatten().cpu().item())
        return np.mean(ppl), np.std(ppl)
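
# Hypothetical usage sketch, mirroring the GPT-1 call made in __main__ below.
# The checkpoint name is a real Hugging Face model; the sentence is made up.
def _perplexity_example(device='cuda'):
    tok = AutoTokenizer.from_pretrained('openai-gpt')
    lm = AutoModelWithLMHead.from_pretrained('openai-gpt').to(device)
    lm.eval()
    # Returns (mean, std) of per-sentence perplexities over the list.
    return perplexity(['the cat sat on the mat .'], tok, lm, device=device)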
def grammaticality(sentences, tokenizer, model, device='cuda'):
    # Average probability that a CoLA-trained classifier assigns to the
    # "grammatical" label (index 1) across all generations.
    with torch.no_grad():
        total_good = 0
        for sent in tqdm(sentences, total=len(sentences)):
            good_prob = F.softmax(model(tokenizer.encode(sent, return_tensors='pt').to(device))[0].flatten(), dim=0)[1]
            total_good += good_prob.item()  # accumulate as a Python float rather than a GPU tensor
        return total_good / len(sentences)  # avg probability of grammaticality according to model
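
# Hypothetical usage sketch: the CoLA classifier used in __main__ outputs two
# logits per sentence, and index 1 is treated as the "grammatical" class.
def _grammaticality_example(device='cuda'):
    tok = AutoTokenizer.from_pretrained('textattack/roberta-base-CoLA')
    clf = AutoModelForSequenceClassification.from_pretrained('textattack/roberta-base-CoLA').to(device)
    clf.eval()
    # Returns a float in [0, 1]; higher means the classifier judges the
    # sentences more grammatical on average.
    return grammaticality(['She writes very clearly .'], tok, clf, device=device)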
def distinctness(results):
    # Distinct-1/2/3: fraction of unique unigrams, bigrams, and trigrams among
    # all generated tokens, computed per conditioning category and averaged.
    d1, d2, d3 = defaultdict(lambda: set()), defaultdict(lambda: set()), defaultdict(lambda: set())
    total_words = defaultdict(lambda: 0)
    for cw, outputs in results.items():
        for o in outputs:
            o = o.replace(EOT_TOKEN, ' ').strip().split(' ')
            total_words[cw] += len(o)
            d1[cw].update(o)
            for i in range(len(o) - 1):
                d2[cw].add(o[i] + ' ' + o[i+1])
            for i in range(len(o) - 2):
                d3[cw].add(o[i] + ' ' + o[i+1] + ' ' + o[i+2])
    return_info = []
    avg_d1, avg_d2, avg_d3 = 0, 0, 0
    for cw in total_words.keys():
        return_info.append((cw, 'DISTINCTNESS', len(d1[cw]) / total_words[cw], len(d2[cw]) / total_words[cw], len(d3[cw]) / total_words[cw]))
        avg_d1 += len(d1[cw]) / total_words[cw]
        avg_d2 += len(d2[cw]) / total_words[cw]
        avg_d3 += len(d3[cw]) / total_words[cw]
    avg_d1, avg_d2, avg_d3 = avg_d1 / len(total_words.keys()), avg_d2 / len(total_words.keys()), avg_d3 / len(total_words.keys())
    return return_info, (avg_d1, avg_d2, avg_d3)
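
# Hypothetical usage sketch: distinctness expects a dict mapping each
# conditioning category to its list of generations. Identical outputs drive
# distinct-1/2/3 toward zero; fully unique n-grams give values near one.
def _distinctness_example():
    fake_results = {'space': ['the stars shine bright', 'the stars shine bright'],
                    'politics': ['the senate passed the bill today']}
    per_category, (d1, d2, d3) = distinctness(fake_results)
    return per_category, (d1, d2, d3)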

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--log_file', type=str, required=True, help='where to load results from')
    parser.add_argument('--tw_dir', type=str, default='test_wordlists', help='directory of test wordlists')
    parser.add_argument('--batch_size', type=int, default=8, help='max samples at a time')
    parser.add_argument('--cap_per_example', type=int, default=None, help='max matches to count per sentence')
    parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
    args = parser.parse_args()

    tw_topic_match_c_total = 0
    category_totals_c = defaultdict(lambda: 0)
    results = defaultdict(lambda: [])

    # The log file is a CSV with one generation per row, keyed by its category.
    with open(args.log_file, 'r') as rf:
        data = list(csv.DictReader(rf))
        for line in data:
            results[line['category']].append(line['generation'])

    all_c_sents = []
    for category, condition_results in results.items():
        tw_topic_match_c = tw_topic_eval(condition_results, category, args.tw_dir, cap=args.cap_per_example)
        tw_topic_match_c_total += tw_topic_match_c
        category_totals_c[category] += tw_topic_match_c
        all_c_sents += condition_results

    print('Test wordlist matches (divide by num outputs to get the Success metric):', tw_topic_match_c_total)
    print('per category:', category_totals_c)

    dist_info_by_category, dist_overall = distinctness(results)
    print('Overall avg distinctness:', dist_overall)
    print('per category:', dist_info_by_category)

    grammar_tokenizer = AutoTokenizer.from_pretrained('textattack/roberta-base-CoLA')
    grammar_model = AutoModelForSequenceClassification.from_pretrained('textattack/roberta-base-CoLA').to(args.device)
    grammar_model.eval()
    print('grammaticality:', grammaticality(all_c_sents, grammar_tokenizer, grammar_model, device=args.device))

    eval_tokenizer = AutoTokenizer.from_pretrained('openai-gpt')
    eval_model = AutoModelWithLMHead.from_pretrained('openai-gpt').to(args.device)
    eval_model.eval()
    print('GPT perplexity:', perplexity(all_c_sents, eval_tokenizer, eval_model, device=args.device))

    eval_tokenizer = AutoTokenizer.from_pretrained('transfo-xl-wt103')
    eval_model = AutoModelWithLMHead.from_pretrained('transfo-xl-wt103').to(args.device)
    eval_model.eval()
    print('TFXL perplexity:', perplexity(all_c_sents, eval_tokenizer, eval_model, device=args.device))
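
# Example invocation (the script name here is hypothetical; the log file must
# be a CSV with 'category' and 'generation' columns, as read above):
#   python evaluate.py --log_file topic_generations.csv --tw_dir test_wordlists --device cuda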