import os
import random
import time
import pickle
import math
from argparse import ArgumentParser
from typing import Iterable, List, Optional, Tuple

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelWithLMHead
from torch import Tensor

from fudge.data import Dataset
from fudge.model import Model
from fudge.util import num_params
from fudge.constants import *

# Tokenizer of the summarizer (generation side).
tokenizer = AutoTokenizer.from_pretrained('google/pegasus-xsum')
# Tokenizer of the conditioning classifier; candidates are re-tokenized with it before scoring.
classifier_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')


def main(args):
    """Load the summarizer and the conditioning model, then generate for a built-in article.

    Args:
        args: parsed command-line namespace. Reads ``dataset_info``, ``model_string``,
            ``device``, ``ckpt``, ``precondition_topk``, ``do_sample``, ``length_cutoff``
            and ``condition_lambda``; ``input_text`` and ``debug`` are optional.

    The generated strings are printed. With ``args.debug`` set, a pdb prompt is opened
    after every generation and continuing from it generates again.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only point --dataset_info
    # at a file you trust.
    with open(args.dataset_info, 'rb') as rf:
        dataset_info = pickle.load(rf)

    article_content = """Australian actor Guy Pearce will return for the iconic soap Neighbours finale on August 1 to reprise his role as Mike Young. Guy, 54, played the troubled Mike from 1986 to 1989, and is now set to make a comeback on the show after 33 years, Metro.co.uk reports. The star's character arcs explored the implications of domestic abuse, student-teacher relationships and dealing with loss of loved ones. Speaking to Metro.co.uk, Guy said: 'It is very exciting and surreal at the same time being back on set again, however it feels like coming home. 'It's where it all started for me professionally. I've been asked to come back on occasions over the years and wondered if it was the right thing to do, but once I knew the show was finishing, I knew I had to do it.'He added that there is 'nothing like being here all together again' , even though he's had a chance to catch-up with other cast members."""

    tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
    pad_id = tokenizer.encode(PAD_TOKEN)[0]

    # For loading Clickbait summarizer
    model = AutoModelWithLMHead.from_pretrained(args.model_string, return_dict=True).to(args.device)
    model.eval()

    checkpoint = torch.load(args.ckpt, map_location=args.device)
    model_args = checkpoint['args']
    # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
    conditioning_model = Model(model_args, pad_id, len(dataset_info.index2word))
    conditioning_model.load_state_dict(checkpoint['state_dict'])
    conditioning_model = conditioning_model.to(args.device)
    conditioning_model.eval()
    print("=> loaded checkpoint '{}' (epoch {})"
          .format(args.ckpt, checkpoint['epoch']))
    print('num params', num_params(conditioning_model))

    # Only the number of prompts is consumed downstream; tolerate namespaces without the field.
    input_text = getattr(args, 'input_text', None) or ''
    debug = getattr(args, 'debug', False)

    while True:
        results = generate_clickbait(model,
                                     tokenizer,
                                     conditioning_model,
                                     [input_text],
                                     dataset_info,
                                     precondition_topk=args.precondition_topk,
                                     do_sample=args.do_sample,
                                     length_cutoff=args.length_cutoff,
                                     condition_lambda=args.condition_lambda,
                                     article_content=article_content,
                                     device=args.device)
        print(results)
        if not debug:
            break
        # Interactive inspection; continuing from the prompt generates again.
        import pdb; pdb.set_trace()


def generate_clickbait(model,
                       tokenizer,
                       conditioning_model,
                       input_text,
                       dataset_info,
                       precondition_topk,
                       length_cutoff,
                       condition_lambda=1.0,
                       article_content=None,
                       device='cuda',
                       do_sample=True):
    """Decode token by token, re-weighting the summarizer's top-k candidates with the classifier.

    At every step the ``precondition_topk`` most likely next tokens are each appended to
    the current prefix, decoded to text, re-tokenized with ``classifier_tokenizer`` and
    scored by ``conditioning_model``. The mean log-sigmoid of those scores, scaled by
    ``condition_lambda``, is added to the summarizer logits before the next token is chosen.

    Args:
        model: encoder-decoder language model used for generation.
        tokenizer: tokenizer matching ``model``.
        conditioning_model: classifier called on the candidate continuations.
        input_text: list of prompts; only its length (the batch size) is read here.
            The encoder input and decoder prefix are built as a single row, so pass a
            one-element list.
        dataset_info: not read by this function; kept for interface compatibility.
        precondition_topk: number of candidate tokens considered at each step.
        length_cutoff: decoding stops once the prefix reaches this many tokens.
        condition_lambda: weight of the classifier term; ``0`` skips the classifier.
        article_content: source text fed to the encoder.
        device: device the tokenized tensors are moved to.
        do_sample: when True (default, the historical behaviour) the next token is sampled
            from the softmax over the re-weighted candidates; when False the highest
            scoring candidate is taken.

    Returns:
        A list with one decoded string per row of the final decoder sequence.

    Raises:
        ValueError: if the decoder prompt tokenizes to nothing and the model config
            defines neither ``decoder_start_token_id`` nor ``pad_token_id``.
    """
    with torch.no_grad():
        batch_size = len(input_text)

        max_input_length = 512
        # truncation is requested explicitly so max_length is actually enforced.
        encoded_input_article = tokenizer(article_content,
                                          return_tensors='pt',
                                          add_special_tokens=False,
                                          truncation=True,
                                          max_length=max_input_length).to(device)  # batch x seq

        # Decoder prefix. .long() is a no-op for integer ids and protects against a
        # zero-width tensor coming back with a floating dtype.
        encoded_input = tokenizer('', return_tensors='pt', add_special_tokens=False)['input_ids']
        encoded_input = encoded_input.long().to(device)  # batch x seq
        if encoded_input.shape[1] == 0:
            # The decoder needs at least one position to produce logits from;
            # fall back to the model's configured start token.
            start_id = getattr(model.config, 'decoder_start_token_id', None)
            if start_id is None:
                start_id = getattr(model.config, 'pad_token_id', None)
            if start_id is None:
                raise ValueError('empty decoder prompt and the model config defines no '
                                 'decoder_start_token_id or pad_token_id')
            encoded_input = torch.full((encoded_input.shape[0], 1), start_id,
                                       dtype=torch.long, device=device)

        lengths = torch.LongTensor([encoded_input.shape[1]]).to(device)
        use_cache = True

        # The article is encoded once and reused at every decoding step.
        model_kwargs = {
            'encoder_outputs': model.get_encoder()(input_ids=encoded_input_article['input_ids'],
                                                   attention_mask=encoded_input_article['attention_mask'],
                                                   return_dict=True,
                                                   output_attentions=False,
                                                   output_hidden_states=False),
        }

        while lengths.max() < length_cutoff:
            model_inputs = model.prepare_inputs_for_generation(
                input_ids=encoded_input_article['input_ids'],
                decoder_input_ids=encoded_input,
                attention_mask=encoded_input_article['attention_mask'],
                use_cache=use_cache,
                **model_kwargs
            )
            outputs = model(**model_inputs, return_dict=True)
            logits = outputs.logits[:, -1, :]  # batch x vocab

            if "past_key_values" in outputs:
                model_kwargs["past"] = outputs.past_key_values

            top_logits, top_indices = logits.topk(precondition_topk, dim=1)  # batch x topk
            new_input_candidates = torch.cat(
                [encoded_input.unsqueeze(1).expand(-1, precondition_topk, -1),
                 top_indices.unsqueeze(2)],
                dim=2)  # batch x topk x seq+1
            expanded_lengths = (lengths + 1).unsqueeze(1).expand(batch_size, precondition_topk)  # batch x topk

            if condition_lambda == 0:
                condition_logits = torch.zeros_like(top_logits).float()
                condition_logits = condition_logits.view(batch_size, precondition_topk, -1)  # batch x topk x N
            else:
                # The classifier has its own vocabulary: go through text and re-tokenize.
                decoded_outputs = tokenizer.batch_decode(
                    new_input_candidates.view(-1, new_input_candidates.size(-1)),
                    clean_up_tokenization_spaces=False)
                resulting_tokenization = classifier_tokenizer(decoded_outputs,
                                                              add_special_tokens=False,
                                                              padding='longest')
                encoded_with_classifier = resulting_tokenization['input_ids']
                attention_mask = torch.tensor(resulting_tokenization['attention_mask']).to(model.device)
                tplus1_candidates_classifier = torch.tensor(encoded_with_classifier).view(
                    batch_size, precondition_topk, -1).to(model.device)

                condition_logits = conditioning_model(
                    tplus1_candidates_classifier.flatten(0, 1),  # batch*topk x seq+1
                    expanded_lengths.flatten(0, 1),  # batch*topk
                    None,
                    None,
                    None,
                    attention_mask=attention_mask)
                condition_logits = condition_logits.view(batch_size, precondition_topk, -1)  # batch x topk x N
                # log(sigmoid(x)); the stable form avoids exp() overflowing for large logits.
                condition_logits = F.logsigmoid(condition_logits)

            condition_logits = torch.mean(condition_logits, dim=2)
            full_logits = top_logits + condition_logits * condition_lambda  # batch x topk

            if do_sample:
                post_logits, post_indices = full_logits.topk(precondition_topk, dim=1)
                post_probs = F.softmax(post_logits, dim=1)
                sampled = torch.multinomial(post_probs, 1)  # batch x 1
                index_into_top_indices = post_indices.gather(1, sampled)  # batch x 1
            else:
                index_into_top_indices = full_logits.argmax(dim=1, keepdim=True)  # batch x 1

            # gather picks one candidate per row, so this also holds for batches larger than one.
            next_indices = top_indices.gather(1, index_into_top_indices)  # batch x 1

            encoded_input = torch.cat([encoded_input, next_indices], dim=1)  # batch x seq+1
            lengths = lengths + 1  # batch

        return [tokenizer.decode(s) for s in encoded_input]


if __name__ == '__main__':
    parser = ArgumentParser()

    # DATA
    parser.add_argument('--ckpt', type=str, required=True)
    parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
    parser.add_argument('--model_string', type=str, default='Helsinki-NLP/opus-mt-es-en')

    parser.add_argument('--in_file', type=str, default=None, required=True, help='text to run pred on')
    parser.add_argument('--input_text', type=str, default='',
                        help='prompt passed to the generator (only the number of prompts is used)')

    parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from text generation at each step before conditioning and re-pruning')
    parser.add_argument('--do_sample', action='store_true', default=False, help='sample instead of greedy')
    parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
    parser.add_argument('--length_cutoff', type=int, default=512, help='max length')

    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
    parser.add_argument('--debug', action='store_true', default=False)

    args = parser.parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    main(args)