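"""Formality-controlled translation at inference time.

Translates input text with a pretrained Marian MT model while re-weighting the
top-k next-token candidates at every decoding step with a separately trained
formality classifier (the conditioning model). The structure appears to follow
a FUDGE-style future-discriminator controlled-generation setup.
"""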
import os
import random
import time
import pickle
import math
from argparse import ArgumentParser
from typing import Iterable, List, Optional, Tuple

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model, MarianTokenizer, MarianMTModel
from torch import Tensor

from data import Dataset
from model import Model
from util import save_checkpoint, ProgressMeter, AverageMeter, num_params
from constants import *
def main(args):
    with open(args.dataset_info, 'rb') as rf:
        dataset_info = pickle.load(rf)

    tokenizer = MarianTokenizer.from_pretrained(args.model_string)
    tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
    pad_id = tokenizer.encode(PAD_TOKEN)[0]
    model = MarianMTModel.from_pretrained(args.model_string, return_dict=True).to(args.device)
    model.eval()

    checkpoint = torch.load(args.ckpt, map_location=args.device)
    model_args = checkpoint['args']
    conditioning_model = Model(model_args, pad_id, len(dataset_info.index2word))  # no need to get the glove embeddings when reloading since they're saved in the model ckpt anyway
    conditioning_model.load_state_dict(checkpoint['state_dict'])
    conditioning_model = conditioning_model.to(args.device)
    conditioning_model.eval()
    print("=> loaded checkpoint '{}' (epoch {})"
          .format(args.ckpt, checkpoint['epoch']))
    print('num params', num_params(conditioning_model))

    while True:
        results = predict_formality(model,
                                    tokenizer,
                                    conditioning_model,
                                    [args.input_text],
                                    dataset_info,
                                    precondition_topk=args.precondition_topk,
                                    do_sample=args.do_sample,
                                    length_cutoff=args.length_cutoff,
                                    condition_lambda=args.condition_lambda,
                                    device=args.device)
        print(results)
        import pdb; pdb.set_trace()  # drop into the debugger so further inputs can be tried interactively
def predict_formality(model, tokenizer, conditioning_model, input_text, dataset_info, precondition_topk=200, do_sample=False, length_cutoff=512, condition_lambda=1.0, device='cuda'):
    with torch.no_grad():
        batch_size = len(input_text)

        # assumes all inputs are initially the same length.
        # encode each input sentence into its token ids
        encoded_input = [tokenizer.encode(it, return_tensors='pt').to(device) for it in input_text]  # batch x seq
        encoded_input = torch.cat(encoded_input, dim=0)

        input_ids = torch.LongTensor([[58100]]).to(device)
        cur_len = 1

        max_length = length_cutoff
        min_length = 0
        temperature = 1.0
        top_k = 50
        top_p = 1.0
        repetition_penalty = 1.0
        no_repeat_ngram_size = 0
        bad_words_ids = [[58100]]
        pad_token_id = 58100
        eos_token_id = 0
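        # NOTE: 58100 and 0 are hard-coded rather than read from the tokenizer/config.
        # They presumably correspond to the pad/decoder-start token and the eos token of the
        # specific Marian checkpoint this script was written for; for a different model_string
        # they would likely need to come from model.config.decoder_start_token_id and
        # tokenizer.eos_token_id instead.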
        effective_batch_size = batch_size
        attention_mask = encoded_input.new_ones(encoded_input.shape)
        use_cache = True
        # run the encoder once up front; the decoder is then run incrementally step by step
        model_specific_kwargs = {'encoder_outputs': model.get_encoder()(encoded_input, attention_mask=attention_mask)}
        output = _generate_no_beam_search(model,
                                          conditioning_model,
                                          condition_lambda,
                                          precondition_topk,
                                          input_ids,
                                          cur_len,
                                          max_length,
                                          min_length,
                                          do_sample,
                                          temperature,
                                          top_k,
                                          top_p,
                                          repetition_penalty,
                                          no_repeat_ngram_size,
                                          bad_words_ids,
                                          pad_token_id,
                                          eos_token_id,
                                          batch_size,
                                          attention_mask,
                                          use_cache,
                                          model_specific_kwargs)

        return [tokenizer.decode(s[1:]) for s in output]  # 1: to delete the pad token
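# Illustrative only: a minimal direct call to predict_formality, mirroring main(), with a
# placeholder input sentence (model/tokenizer/conditioning_model/dataset_info loaded as above):
#
#   results = predict_formality(model, tokenizer, conditioning_model,
#                               ['<spanish input sentence>'], dataset_info,
#                               condition_lambda=1.0, device='cpu')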
# hack of code from transformers/generation_utils.py
# to get our conditioning
def postprocess_next_token_scores(
    model,
    scores,
    input_ids,
    no_repeat_ngram_size,
    bad_words_ids,
    cur_len,
    min_length,
    max_length,
    eos_token_id,
    repetition_penalty,
    batch_size,
    num_beams,
):
    # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
    if repetition_penalty != 1.0:
        model.enforce_repetition_penalty_(
            scores,
            batch_size,
            num_beams,
            input_ids,
            repetition_penalty,
        )

    # set eos token prob to zero if min_length is not reached
    if eos_token_id is not None and cur_len < min_length:
        scores[:, eos_token_id] = -float("inf")

    if no_repeat_ngram_size > 0:
        # calculate a list of banned tokens to prevent repetitively generating the same ngrams
        num_batch_hypotheses = batch_size * num_beams
        # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
        banned_batch_tokens = calc_banned_ngram_tokens(
            input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
        )
        for i, banned_tokens in enumerate(banned_batch_tokens):
            scores[i, banned_tokens] = -float("inf")

    if bad_words_ids is not None:
        # Exclude EOS token (already processed)
        bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
        # calculate a list of banned tokens according to bad words
        banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids)
        # Modify the scores in place by setting the banned tokens logits to `-inf`
        set_scores_to_inf_for_banned_tokens(scores, banned_tokens)

    return scores
def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> List[List[int]]:
    """Copied from fairseq for no_repeat_ngram in beam_search"""
    if cur_len + 1 < no_repeat_ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        generated_ngram = generated_ngrams[idx]
        for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]

    def _get_generated_ngrams(hypo_idx):
        # Before decoding the next token, prevent decoding of ngrams that have already appeared
        start_idx = cur_len + 1 - no_repeat_ngram_size
        ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
        return generated_ngrams[hypo_idx].get(ngram_idx, [])

    banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
    return banned_tokens
def calc_banned_bad_words_ids(prev_input_ids: List[List[int]], bad_words_ids: List[List[int]]) -> List[List[int]]:
    banned_tokens = []

    def _tokens_match(prev_tokens, tokens):
        if len(tokens) == 0:
            # if bad word tokens is just one token always ban it
            return True
        if len(tokens) > len(prev_tokens):
            # if bad word tokens are longer than prev tokens they can't be equal
            return False

        if prev_tokens[-len(tokens):] == tokens:
            # if tokens match
            return True
        else:
            return False

    for prev_input_ids_slice in prev_input_ids:
        banned_tokens_slice = []
        for banned_token_seq in bad_words_ids:
            assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
                bad_words_ids
            )

            if _tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False:
                # if tokens do not match continue
                continue

            banned_tokens_slice.append(banned_token_seq[-1])

        banned_tokens.append(banned_tokens_slice)

    return banned_tokens
def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
    """Modifies the scores in place by setting the banned token positions to `-inf`.

    Args:
        scores: logits distribution of shape (batch size, vocabulary size)
        banned_tokens: list of length batch_size, where entry i is the list of token ids banned for batch element i
    """
    banned_mask_list = []
    for idx, batch_banned_tokens in enumerate(banned_tokens):
        for token in batch_banned_tokens:
            banned_mask_list.append([idx, token])
    if not banned_mask_list:
        return
    banned_mask = torch.LongTensor(banned_mask_list)
    indices = torch.ones(len(banned_mask))
    # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
    # [ 0  1  1 ]
    # [ 0  0  0 ]
    # [ 1  0  0 ]
    banned_mask = torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool()
    scores.masked_fill_(banned_mask, -float("inf"))
def _generate_no_beam_search(
    model,
    conditioning_model,
    condition_lambda,
    precondition_topk,
    input_ids,
    cur_len,
    max_length,
    min_length,
    do_sample,
    temperature,
    top_k,
    top_p,
    repetition_penalty,
    no_repeat_ngram_size,
    bad_words_ids,
    pad_token_id,
    eos_token_id,
    batch_size,
    attention_mask,
    use_cache,
    model_kwargs,
):
    """Generate sequences for each example without beam search (num_beams == 1).
    All returned sequences are generated independently.
    """
    # length of generated sentences / unfinished sentences
    unfinished_sents = input_ids.new(batch_size).fill_(1)
    sent_lengths = input_ids.new(batch_size).fill_(max_length)

    past = None
    while cur_len < max_length:
        model_inputs = model.prepare_inputs_for_generation(
            input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
        )

        outputs = model(**model_inputs, return_dict=True)
        next_token_logits = outputs.logits[:, -1, :]
        # the stock model.postprocess_next_token_scores call is replaced here by the local copy above
        scores = postprocess_next_token_scores(
            model=model,
            scores=next_token_logits,
            input_ids=input_ids,
            no_repeat_ngram_size=no_repeat_ngram_size,
            bad_words_ids=bad_words_ids,
            cur_len=cur_len,
            min_length=min_length,
            max_length=max_length,
            eos_token_id=eos_token_id,
            repetition_penalty=repetition_penalty,
            batch_size=batch_size,
            num_beams=1,
        )

        # if model has past, then set the past variable to speed up decoding
        if "past_key_values" in outputs:
            past = outputs.past_key_values
        elif "mems" in outputs:
            past = outputs.mems
        top_logits, top_indices = scores.topk(precondition_topk, dim=1)  # batch x topk
        tplus1_candidates = torch.cat([input_ids.unsqueeze(1).expand(-1, precondition_topk, -1), top_indices.unsqueeze(2)], dim=2)[:, :, 1:]  # batch x topk x seq+1, with pad dropped
        expanded_lengths = torch.LongTensor([[cur_len for _ in range(precondition_topk)] for _ in range(batch_size)]).to(scores.device)
        if condition_lambda == 0:
            condition_logits = torch.zeros_like(top_logits).float()
        else:
            condition_logits = conditioning_model(tplus1_candidates.flatten(0, 1),  # batch*topk x seq+1
                                                  expanded_lengths.flatten(0, 1),  # batch*topk
                                                  None,
                                                  None,
                                                  None)
            condition_logits = condition_logits.view(batch_size, precondition_topk, -1)[:, :, -1]  # batch x topk of last formality pred
            condition_logits = condition_logits - torch.log(1 + torch.exp(condition_logits))  # log sigmoid(logit), i.e. the correct log prob of "formal"
            # condition_logits = - torch.log(1 + torch.exp(condition_logits))  # log(1 - sigmoid(logit)), for informal
        full_logits = top_logits + condition_lambda * condition_logits
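        # Net effect: each top-k candidate is ranked by
        #   base-model logit + condition_lambda * log p(formal | prefix + candidate)
        # (equivalent to combining log-probabilities up to a shared softmax constant), so the
        # conditioning model only ever has to score precondition_topk continuations per step.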
        if do_sample:
            raise NotImplementedError
        else:
            # Greedy decoding
            next_token = top_indices[torch.arange(batch_size).to(top_indices.device), torch.argmax(full_logits, dim=-1)]

        # original transformers sampling/greedy branch, kept for reference:
        # if do_sample:
        #     # Temperature (higher temperature => more likely to sample low probability tokens)
        #     if temperature != 1.0:
        #         scores = scores / temperature
        #     # Top-p/top-k filtering
        #     next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
        #     # Sample
        #     probs = F.softmax(next_token_logscores, dim=-1)
        #     next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
        # else:
        #     # Greedy decoding
        #     next_token = torch.argmax(next_token_logits, dim=-1)
        # update generations and finished sentences
        if eos_token_id is not None:
            # pad finished sentences if eos_token_id exists
            tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
        else:
            tokens_to_add = next_token

        # add token and increase length by one
        input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
        cur_len = cur_len + 1

        if eos_token_id is not None:
            eos_in_sents = tokens_to_add == eos_token_id
            # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
            is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
            sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
            # unfinished_sents is set to zero if eos in sentence
            unfinished_sents.mul_((~eos_in_sents).long())

        # stop when there is a </s> in each sentence, or if we exceed the maximum length
        if unfinished_sents.max() == 0:
            break

        # extend attention_mask for new generated input if only decoder
        if model.config.is_encoder_decoder is False:
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )

    return input_ids
if __name__ == '__main__':
    parser = ArgumentParser()

    # DATA
    parser.add_argument('--ckpt', type=str, required=True)
    parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
    parser.add_argument('--model_string', type=str, default='Helsinki-NLP/opus-mt-es-en')

    parser.add_argument('--input_text', type=str, default=None, required=True, help='text to run pred on')

    parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from the base model at each step before conditioning and re-pruning')
    parser.add_argument('--do_sample', action='store_true', default=False, help='sample instead of greedy')
    parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
    parser.add_argument('--length_cutoff', type=int, default=512, help='max length')

    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
    parser.add_argument('--debug', action='store_true', default=False)

    args = parser.parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    main(args)
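# Example invocation (the script name and paths below are placeholders for illustration):
#   python predict_formality.py --ckpt <formality_classifier_ckpt> \
#       --dataset_info <dataset_info_file> --input_text "<spanish sentence>" --device cpu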