# clickbaitonator / fudge / predict_clickbait.py
# Last change: commit 4183d85 "Update fudge/predict_clickbait.py" (author: Dusan, 10.4 kB)
import os
import random
import time
import pickle
import math
from argparse import ArgumentParser
from typing import Iterable, List, Optional, Tuple
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelWithLMHead
from torch import Tensor
from fudge.data import Dataset
from fudge.model import Model
from fudge.util import num_params
from fudge.constants import *
# Tokenizer for the seq2seq summarizer: its ids are fed to the model loaded in
# main() as decoder inputs, so the two must share a vocabulary.
tokenizer = AutoTokenizer.from_pretrained(
    'google/pegasus-xsum',
)
# Tokenizer used to re-encode candidate prefixes before they are scored by the
# conditioning model in generate_clickbait().
classifier_tokenizer = AutoTokenizer.from_pretrained(
    'sentence-transformers/all-mpnet-base-v2',
)
def main(args):
    """Load the summarizer and conditioning model, then print one generated summary.

    The article text is read from ``args.in_file``. The seq2seq model is loaded
    from ``args.model_string``, and the conditioning model from the checkpoint
    at ``args.ckpt``. ``generate_clickbait`` is then run once, and each
    resulting string is printed.

    Args:
        args: Parsed command-line namespace from the ``__main__`` argument parser.

    Returns:
        The list of decoded strings produced by ``generate_clickbait``.
    """
    # NOTE(review): pickle.load and torch.load can execute arbitrary code stored
    # in the file -- only point --dataset_info / --ckpt at trusted files.
    with open(args.dataset_info, 'rb') as rf:
        dataset_info = pickle.load(rf)

    with open(args.in_file, 'r', encoding='utf-8') as f:
        article_content = f.read()

    # Registers PAD_TOKEN on the shared module-level tokenizer.
    tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
    pad_id = tokenizer.encode(PAD_TOKEN)[0]

    # Seq2seq summarizer whose decoder output is steered by the conditioning model.
    model = AutoModelWithLMHead.from_pretrained(args.model_string, return_dict=True).to(args.device)
    model.eval()

    checkpoint = torch.load(args.ckpt, map_location=args.device)
    model_args = checkpoint['args']
    # GloVe embeddings need not be loaded here: they are saved in the checkpoint.
    conditioning_model = Model(model_args, pad_id, len(dataset_info.index2word))
    conditioning_model.load_state_dict(checkpoint['state_dict'])
    conditioning_model = conditioning_model.to(args.device)
    conditioning_model.eval()
    print("=> loaded checkpoint '{}' (epoch {})"
          .format(args.ckpt, checkpoint['epoch']))
    print('num params', num_params(conditioning_model))

    # generate_clickbait only uses len(input_text) (as its batch size), so a
    # one-element placeholder list is passed. --do_sample is not forwarded:
    # generate_clickbait has no do_sample parameter and always samples.
    results = generate_clickbait(model,
                                 tokenizer,
                                 conditioning_model,
                                 [''],
                                 dataset_info,
                                 precondition_topk=args.precondition_topk,
                                 length_cutoff=args.length_cutoff,
                                 condition_lambda=args.condition_lambda,
                                 article_content=article_content,
                                 device=args.device)
    for result in results:
        print(result)
    return results
def generate_clickbait(model,
                       tokenizer,
                       conditioning_model,
                       input_text,
                       dataset_info,
                       precondition_topk,
                       length_cutoff,
                       condition_lambda=1.0,
                       article_content=None,
                       device='cuda'):
    """Summarize ``article_content`` with FUDGE-style guided sampling.

    At every step the summarizer's ``precondition_topk`` most likely next tokens
    are re-scored. Each candidate's score is its LM logit plus
    ``condition_lambda`` times the mean log-sigmoid of the conditioning model's
    outputs for that candidate prefix. The next token is then sampled from the
    softmax of the re-scored candidates. Decoding stops when every sequence has
    emitted ``tokenizer.eos_token_id``, or when ``length_cutoff`` is reached.

    Args:
        model: Encoder-decoder LM exposing ``get_encoder`` and
            ``prepare_inputs_for_generation``.
        tokenizer: Tokenizer matching ``model``.
        conditioning_model: Scorer for candidate prefixes. Its inputs are
            produced by the module-level ``classifier_tokenizer``.
        input_text: Not read. Decoding always starts from a single '<pad>'
            token. Kept for interface compatibility.
        dataset_info: Not used here. Kept for interface compatibility.
        precondition_topk: Number of candidate tokens re-scored per step.
        length_cutoff: Maximum decoder length in tokens, counting the leading pad.
        condition_lambda: Weight on the conditioning scores; 0 disables them.
        article_content: Article text for the encoder, truncated to 512 tokens.
        device: Device the tokenized encoder/decoder inputs are moved to.

    Returns:
        A list with one decoded string per sequence. Special tokens are kept,
        because ``decode`` is called without ``skip_special_tokens``.

    Raises:
        ValueError: If ``article_content`` is None.
    """
    if article_content is None:
        raise ValueError('article_content must be provided')

    with torch.no_grad():
        max_input_length = 512
        # truncation=True makes the 512-token cap explicit instead of relying on
        # the tokenizer's version-dependent handling of a bare max_length.
        encoded_input_article = tokenizer(article_content,
                                          return_tensors='pt',
                                          add_special_tokens=False,
                                          max_length=max_input_length,
                                          truncation=True).to(device)  # batch x seq

        # Decoding starts from a single '<pad>' token.
        encoded_input = tokenizer('<pad>', return_tensors='pt',
                                  add_special_tokens=False).to(device)['input_ids']  # batch x seq
        # The real batch size is the number of decoder rows, not len(input_text).
        batch_size = encoded_input.size(0)
        lengths = torch.full((batch_size,), encoded_input.size(1), dtype=torch.long, device=device)
        eos_id = tokenizer.eos_token_id
        finished = torch.zeros(batch_size, dtype=torch.bool, device=encoded_input.device)

        # Run the encoder once; its outputs are reused at every decoding step.
        model_kwargs = {'encoder_outputs': model.get_encoder()(
            input_ids=encoded_input_article['input_ids'],
            attention_mask=encoded_input_article['attention_mask'],
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False)}
        use_cache = True

        while lengths.max() < length_cutoff:
            model_inputs = model.prepare_inputs_for_generation(
                input_ids=encoded_input_article['input_ids'],
                decoder_input_ids=encoded_input,
                attention_mask=encoded_input_article['attention_mask'],
                use_cache=use_cache,
                **model_kwargs
            )
            outputs = model(**model_inputs, return_dict=True)
            logits = outputs.logits[:, -1, :]  # batch x vocab
            if "past_key_values" in outputs:
                # NOTE(review): the cache is passed back under the key 'past';
                # newer transformers versions name it past_key_values -- confirm
                # that caching is actually used with the installed version.
                model_kwargs["past"] = outputs.past_key_values

            top_logits, top_indices = logits.topk(precondition_topk, dim=1)  # batch x topk
            new_input_candidates = torch.cat(
                [encoded_input.unsqueeze(1).expand(-1, precondition_topk, -1),
                 top_indices.unsqueeze(2)],
                dim=2)  # batch x topk x seq+1
            expanded_lengths = (lengths + 1).unsqueeze(1).expand(batch_size, precondition_topk)  # batch x topk

            if condition_lambda == 0:
                condition_logits = torch.zeros_like(top_logits).float()
                condition_logits = condition_logits.view(batch_size, precondition_topk, -1)  # batch x topk x N
            else:
                # Decode each candidate prefix and re-tokenize it for the classifier.
                decoded_outputs = tokenizer.batch_decode(
                    new_input_candidates.view(-1, new_input_candidates.size(-1)),
                    clean_up_tokenization_spaces=False)
                resulting_tokenization = classifier_tokenizer(decoded_outputs,
                                                              add_special_tokens=False,
                                                              padding='longest')
                classifier_mask = torch.tensor(resulting_tokenization['attention_mask']).to(model.device)
                tplus1_candidates_classifier = torch.tensor(
                    resulting_tokenization['input_ids']
                ).view(batch_size, precondition_topk, -1).to(model.device)
                # NOTE(review): expanded_lengths counts summarizer tokens, not
                # classifier tokens -- confirm the conditioning model does not
                # rely on it when attention_mask is supplied.
                condition_logits = conditioning_model(tplus1_candidates_classifier.flatten(0, 1),  # batch*topk x seq+1
                                                      expanded_lengths.flatten(0, 1),  # batch*topk
                                                      None,
                                                      None,
                                                      None,
                                                      attention_mask=classifier_mask)
                condition_logits = condition_logits.view(batch_size, precondition_topk, -1)  # batch x topk x N

            # log(sigmoid(x)): numerically stable form of x - log(1 + exp(x)),
            # which overflows to -inf for large x.
            condition_logits = F.logsigmoid(condition_logits)
            condition_logits = torch.mean(condition_logits, dim=2)  # batch x topk

            full_logits = top_logits + condition_logits * condition_lambda  # batch x topk
            post_logits, post_indices = full_logits.topk(precondition_topk, dim=1)
            post_probs = F.softmax(post_logits, dim=1)

            # Sample one re-ranked candidate per row and map it back to a vocab id.
            rows = torch.arange(batch_size, device=post_indices.device)
            sampled = torch.multinomial(post_probs, 1).flatten()  # batch
            index_into_top_indices = post_indices[rows, sampled]  # batch
            next_indices = top_indices[rows, index_into_top_indices]  # batch

            encoded_input = torch.cat([encoded_input, next_indices.unsqueeze(1)], dim=1)  # batch x seq+1
            lengths = lengths + 1  # batch

            # Stop once every sequence has produced EOS; later tokens are noise.
            if eos_id is not None:
                finished |= next_indices == eos_id
                if finished.all():
                    break

    return [tokenizer.decode(s) for s in encoded_input]
if __name__ == '__main__':
    parser = ArgumentParser()

    # DATA
    parser.add_argument('--ckpt', type=str, required=True)
    parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
    # The module-level decoder tokenizer is google/pegasus-xsum, so the default
    # summarizer must share its vocabulary (the previous default was a
    # Spanish-English Marian MT model with a different vocabulary).
    parser.add_argument('--model_string', type=str, default='google/pegasus-xsum')
    parser.add_argument('--in_file', type=str, required=True, help='file containing the article text to run pred on')

    # GENERATION
    parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from text generation at each step before conditioning and re-pruning')
    parser.add_argument('--do_sample', action='store_true', default=False, help='sample instead of greedy')
    parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
    parser.add_argument('--length_cutoff', type=int, default=512, help='max length')

    # MISC
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
    parser.add_argument('--debug', action='store_true', default=False)

    args = parser.parse_args()

    # Seed every RNG in use so sampling is reproducible for a given --seed.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    main(args)