import torch
from torch import nn
from torch.nn import functional as F
from transformers import (
    BartModel,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_bart import _reorder_buffer  # needed by _reorder_cache below
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.generation_utils import top_k_top_p_filtering
from transformers.file_utils import ModelOutput
from typing import Iterable, List, Optional


class BartConstrainedGen(PreTrainedModel):
    def __init__(self, config, tokenizer):
        super(BartConstrainedGen, self).__init__(config)
        self.config = config
        self.tokenizer = tokenizer
        self.transformer = BartModel.from_pretrained('facebook/bart-large')
        self.register_buffer("final_logits_bias", torch.zeros((1, self.transformer.shared.num_embeddings)))

    def resize_token_embeddings(self):
        old_num_tokens = self.transformer.shared.num_embeddings
        new_embeddings = self.transformer.resize_token_embeddings(len(self.tokenizer))
        self.transformer.shared = new_embeddings
        self._resize_final_logits_bias(len(self.tokenizer), old_num_tokens)
        self.vocab_size = len(self.tokenizer)
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None:
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)
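    # For illustration (hypothetical sizes): with the stock facebook/bart-large
    # vocabulary of 50265 tokens and two added special tokens,
    #   model._resize_final_logits_bias(50267, 50265)
    #   model.final_logits_bias.shape  # -> torch.Size([1, 50267])
    # appends two zero columns; shrinking the vocabulary instead truncates the buffer.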
    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, torch.nn.LayerNorm):
            # if using apex, this should be FusedLayerNorm
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def get_encoder(self):
        return self.transformer.encoder

    def get_output_embeddings(self):
        # this method is needed for generation
        vocab_size, emb_size = self.transformer.shared.weight.shape
        lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
        lin_layer.weight.data = self.transformer.shared.weight.data
        return lin_layer

    def prepare_inputs_for_generation(
            self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, input_embeds,
            encoder_input_ids, **kwargs):
        return {
            "input_ids": encoder_input_ids,  # encoder_outputs is defined, so input_ids is not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
            "input_embeds": input_embeds,
        }

    def adjust_logits_during_generation(self, logits, cur_len, max_length):
        if cur_len == 1 and self.config.force_bos_token_to_be_generated:
            self._force_token_ids_generation(logits, self.config.bos_token_id)
        elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
            self._force_token_ids_generation(logits, self.config.eos_token_id)
        return logits

    def _force_token_ids_generation(self, scores, token_id) -> None:
        """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
        scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float("inf")

    @staticmethod
    def _reorder_cache(past, beam_idx):
        reordered_past = []
        for layer_past in past:
            # get the correct batch idx from decoder layer's batch dim for cross and self-attn
            layer_past_new = {
                attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
            }
            reordered_past.append(layer_past_new)
        return reordered_past

    def convert_pointer_logits_to_lm_logits(self, pointer_logits, input_ids):
        '''
        pointer_logits: (batch, seq_len, input_seq_len)
        input_ids: (batch, input_seq_len)
        lm_logits: (batch, seq_len, vocab_size)
        '''
        batch_size = pointer_logits.size(0)
        seq_len = pointer_logits.size(1)
        input_seq_len = input_ids.size(1)
        lm_logits = torch.full((batch_size, seq_len, self.vocab_size), fill_value=-1000,
                               dtype=pointer_logits.dtype).to(pointer_logits.device)
        # scatter may be technically incorrect for duplicate indexes, but not using it gets slow
        index = input_ids.unsqueeze(dim=1).expand_as(pointer_logits)
        lm_logits.scatter_(dim=2, index=index, src=pointer_logits)
        return lm_logits
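    # Worked example for convert_pointer_logits_to_lm_logits (hypothetical sizes,
    # for illustration only): with self.vocab_size = 10, input_ids = [[4, 7, 9]]
    # and a single decoding step of pointer_logits = [[[2.0, 0.5, 1.0]]]
    # (shape (1, 1, 3)), the scatter produces lm_logits of shape (1, 1, 10) filled
    # with -1000 except
    #   lm_logits[0, 0, 4] = 2.0, lm_logits[0, 0, 7] = 0.5, lm_logits[0, 0, 9] = 1.0,
    # i.e. only tokens that occur in the source sequence can receive a high score.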
    def remove_unseen(self, lm_logits, input_ids):
        # input_ids: (batch, seq)
        seen_lm_logits = torch.full_like(lm_logits, fill_value=-1000).to(lm_logits.device)  # (batch, seq, vocab)
        seen_vocab = set(input_ids.reshape(-1).tolist())
        for i in range(self.transformer.config.vocab_size):
            if i in seen_vocab:
                seen_lm_logits[:, :, i] = lm_logits[:, :, i]
            elif i == self.tokenizer.encode(' and', add_special_tokens=False)[0]:
                seen_lm_logits[:, :, i] = lm_logits[:, :, i]
        return seen_lm_logits

    def forward(self, input_ids,
                attention_mask=None,
                encoder_outputs=None,
                use_cache=False,
                past_key_values=None,
                decoder_input_ids=None,
                decoder_attention_mask=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                input_embeds=None,
                task=-1):

        # generation
        if task == -1:
            outputs = self.transformer(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                use_cache=use_cache,
                encoder_outputs=encoder_outputs,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                input_embeds=input_embeds,
                return_dict=return_dict,)

            decoder_output = outputs[0]  # (batch, seq_len, hidden_dim)
            if encoder_outputs is None:
                encoder_outputs = outputs[1]  # (batch, input_seq_len, hidden_dim); BaseModelOutput if return_dict
            if input_embeds is None:
                # get encoder side embeddings
                input_embeds = self.transformer.encoder.embed_tokens(input_ids) * self.transformer.encoder.embed_scale
            # (batch, seq_len, input_seq_len)
            pointer_logits = torch.einsum('ijk,ilk->ijl', decoder_output, input_embeds)
            # (batch, seq_len, vocab_size)
            lm_logits = self.convert_pointer_logits_to_lm_logits(pointer_logits, input_ids)

            masked_lm_loss = None
            if not return_dict:
                output = (lm_logits,) + outputs[1:]
                return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

            return Seq2SeqLMOutput(
                loss=masked_lm_loss,
                logits=lm_logits,
                past_key_values=outputs.past_key_values,
                decoder_hidden_states=outputs.decoder_hidden_states,
                decoder_attentions=outputs.decoder_attentions,
                encoder_last_hidden_state=outputs.encoder_last_hidden_state,
                encoder_hidden_states=outputs.encoder_hidden_states,
                encoder_attentions=outputs.encoder_attentions,
            )

        # training
        elif task == 0:
            assert decoder_input_ids is not None
            y_ids = decoder_input_ids[:, :-1]
            labels = decoder_input_ids[:, 1:].clone()
            labels[labels == self.tokenizer.pad_token_id] = -100
            # teacher forcing: the decoder is fed decoder_input_ids[:, :-1] and trained
            # to predict the next token, decoder_input_ids[:, 1:]
            outputs = self.transformer(
                input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=y_ids,
                decoder_attention_mask=decoder_attention_mask[:, :-1],
                use_cache=False,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,)

            decoder_output = outputs[0]  # (batch, seq_len, hidden_dim)
            encoder_output = outputs[1]  # (batch, input_seq_len, hidden_dim)
            # lm_logits = F.linear(decoder_output, self.transformer.shared.weight, bias=self.final_logits_bias)
            # lm_logits = self.remove_unseen(lm_logits, input_ids)

            # get encoder side embeddings
            input_embeds = self.transformer.encoder.embed_tokens(input_ids) * self.transformer.encoder.embed_scale
            # (batch, seq_len, input_seq_len)
            pointer_logits = torch.einsum('ijk,ilk->ijl', decoder_output, input_embeds)
            # decrease prob if necessary
            lm_logits = self.convert_pointer_logits_to_lm_logits(pointer_logits, input_ids)

            outputs = (lm_logits,) + outputs[1:]  # Add cache, hidden states and attention if they are here
            loss_fct = nn.CrossEntropyLoss()
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.vocab_size), labels.view(-1))
            outputs = (masked_lm_loss,) + outputs
            return outputs
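    # Note on the training loss above (illustrative token ids): labels equal to the
    # pad id are set to -100, the default ignore_index of nn.CrossEntropyLoss, so
    # padded positions contribute nothing to masked_lm_loss. For example, with
    # decoder_input_ids = [[0, 8, 15, 2, 1]] and pad_token_id = 1, the decoder is fed
    # [[0, 8, 15, 2]] and trained against labels [[8, 15, 2, -100]].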
    # this is a simplified generate function for the pointer generator, taken from
    # https://github.com/huggingface/transformers/blob/v3.1.0/src/transformers/generation_utils.py
    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        max_length: Optional[int] = None,
        min_length: Optional[int] = None,
        do_sample: Optional[bool] = None,
        early_stopping: Optional[bool] = None,
        num_beams: Optional[int] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        bad_words_ids: Optional[Iterable[int]] = None,
        bos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        num_return_sequences: Optional[int] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_start_token_id: Optional[int] = None,
        use_cache: Optional[bool] = None,
        **model_kwargs
    ) -> torch.LongTensor:
        r"""
        Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
        beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.

        Adapted in part from Facebook's XLM beam search code.

        Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the
        attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values
        indicated are the defaults of those config attributes.

        Most of these parameters are explained in more detail in the Hugging Face blog post on text generation.

        Parameters:

            input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
                The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty
                :obj:`torch.LongTensor` of shape :obj:`(1,)`.
            max_length (:obj:`int`, `optional`, defaults to 20):
                The maximum length of the sequence to be generated.
            min_length (:obj:`int`, `optional`, defaults to 10):
                The minimum length of the sequence to be generated.
            do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to use sampling; use greedy decoding otherwise.
            early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
            num_beams (:obj:`int`, `optional`, defaults to 1):
                Number of beams for beam search. 1 means no beam search.
            temperature (:obj:`float`, `optional`, defaults to 1.0):
                The value used to modulate the next token probabilities.
            top_k (:obj:`int`, `optional`, defaults to 50):
                The number of highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (:obj:`float`, `optional`, defaults to 1.0):
                If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or
                higher are kept for generation.
            repetition_penalty (:obj:`float`, `optional`, defaults to 1.0):
                The parameter for repetition penalty. 1.0 means no penalty. See the CTRL paper
                (https://arxiv.org/abs/1909.05858) for more details.
            pad_token_id (:obj:`int`, `optional`):
                The id of the `padding` token.
            bos_token_id (:obj:`int`, `optional`):
                The id of the `beginning-of-sequence` token.
            eos_token_id (:obj:`int`, `optional`):
                The id of the `end-of-sequence` token.
            length_penalty (:obj:`float`, `optional`, defaults to 1.0):
                Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the
                model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer
                sequences.
            no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
                If set to int > 0, all ngrams of that size can only occur once.
            bad_words_ids (:obj:`List[int]`, `optional`):
                List of token ids that are not allowed to be generated. In order to get the tokens of the words that
                should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
            num_return_sequences (:obj:`int`, `optional`, defaults to 1):
                The number of independently computed returned sequences for each element in the batch.
            attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
                Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``, 1 for
                tokens that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same
                shape as :obj:`input_ids` that masks the pad token.

                `What are attention masks? <../glossary.html#attention-mask>`__
            decoder_start_token_id (:obj:`int`, `optional`):
                If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
            use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not the model should use the past last key/values attentions (if applicable to the model) to
                speed up decoding.
            model_kwargs:
                Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
        Return:

            :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`:
            The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
            shorter if all batches finished early due to the :obj:`eos_token_id`.

        Examples::

            tokenizer = AutoTokenizer.from_pretrained('distilgpt2')  # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('distilgpt2')  # Download model and configuration from S3 and cache.
            outputs = model.generate(max_length=40)  # do greedy decoding
            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('openai-gpt')  # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('openai-gpt')  # Download model and configuration from S3 and cache.
            input_context = 'The dog'
            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
            outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5)  # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
            for i in range(3):  # 3 output sequences were generated
                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('distilgpt2')  # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('distilgpt2')  # Download model and configuration from S3 and cache.
            input_context = 'The dog'
            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True)  # generate 3 candidates using sampling
            for i in range(3):  # 3 output sequences were generated
                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('ctrl')  # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('ctrl')  # Download model and configuration from S3 and cache.
            input_context = 'Legal My neighbor is'  # "Legal" is one of the control codes for ctrl
            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2)  # generate sequences
            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('gpt2')  # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('gpt2')  # Download model and configuration from S3 and cache.
            input_context = 'My cute dog'
            bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)  # generate sequences without allowing bad_words to be generated
        """

        max_length = max_length if max_length is not None else self.config.max_length
        min_length = min_length if min_length is not None else self.config.min_length
        do_sample = do_sample if do_sample is not None else self.config.do_sample
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        num_beams = num_beams if num_beams is not None else self.config.num_beams
        temperature = temperature if temperature is not None else self.config.temperature
        top_k = top_k if top_k is not None else self.config.top_k
        top_p = top_p if top_p is not None else self.config.top_p
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        no_repeat_ngram_size = (
            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
        )
        bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
        num_return_sequences = (
            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
        )
        decoder_start_token_id = (
            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
        )

        if input_ids is not None:
            batch_size = input_ids.shape[0]  # overridden by the input batch_size
        else:
            batch_size = 1

        assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
        assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
        assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
        assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
        assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
        assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
        assert temperature > 0, "`temperature` should be strictly positive."
        assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
        assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
        assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
        assert input_ids is not None or (
            isinstance(bos_token_id, int) and bos_token_id >= 0
        ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
        assert pad_token_id is None or (
            isinstance(pad_token_id, int) and (pad_token_id >= 0)
        ), "`pad_token_id` should be a positive integer."
        assert (eos_token_id is None) or (
            isinstance(eos_token_id, int) and (eos_token_id >= 0)
        ), "`eos_token_id` should be a positive integer."
        assert length_penalty > 0, "`length_penalty` should be strictly positive."
        assert (
            isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
        ), "`no_repeat_ngram_size` should be a positive integer."
        assert (
            isinstance(num_return_sequences, int) and num_return_sequences > 0
        ), "`num_return_sequences` should be a strictly positive integer."
        assert (
            bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
        ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"

        if input_ids is None:
            assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
                "you should either supply a context to complete as `input_ids` input "
                "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
            )
            input_ids = torch.full(
                (batch_size, 1),
                bos_token_id,
                dtype=torch.long,
                device=next(self.parameters()).device,
            )
        else:
            assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."

        # do not allow duplicate outputs when greedy decoding
        if do_sample is False:
            if num_beams == 1:
                # no_beam_search greedy generation conditions
                assert (
                    num_return_sequences == 1
                ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
            else:
                # beam_search greedy generation conditions
                assert (
                    num_beams >= num_return_sequences
                ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"

        # create attention mask if necessary
        # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
        if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
            attention_mask = input_ids.ne(pad_token_id).long()
        elif attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)
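        # For illustration (hypothetical token ids): with pad_token_id = 1 and
        # input_ids = [[0, 9064, 2, 1, 1]], the branch above yields
        # attention_mask = [[1, 1, 1, 0, 0]]; if no pad token appears in the input,
        # the mask is simply all ones.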
        # set pad_token_id to eos_token_id if not set. Important that this is done after
        # attention_mask is created
        if pad_token_id is None and eos_token_id is not None:
            pad_token_id = eos_token_id

        # current position and vocab size
        if hasattr(self.config, "vocab_size"):
            vocab_size = self.config.vocab_size
        elif (
            self.config.is_encoder_decoder
            and hasattr(self.config, "decoder")
            and hasattr(self.config.decoder, "vocab_size")
        ):
            vocab_size = self.config.decoder.vocab_size

        # set effective batch size and effective batch multiplier according to do_sample
        if do_sample:
            effective_batch_size = batch_size * num_return_sequences
            effective_batch_mult = num_return_sequences
        else:
            effective_batch_size = batch_size
            effective_batch_mult = 1

        if self.config.is_encoder_decoder:
            if decoder_start_token_id is None:
                # see if BOS token can be used for decoder_start_token_id
                if bos_token_id is not None:
                    decoder_start_token_id = bos_token_id
                elif hasattr(self.config, "decoder") and hasattr(self.config.decoder, "bos_token_id"):
                    decoder_start_token_id = self.config.decoder.bos_token_id
                else:
                    raise ValueError(
                        "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
                    )

            assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
            assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)

            # get encoder and store encoder outputs
            encoder = self.get_encoder()
            encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True)
            input_embeds = encoder.embed_tokens(input_ids) * encoder.embed_scale

        # Expand input ids if num_beams > 1 or num_return_sequences > 1
        if num_return_sequences > 1 or num_beams > 1:
            input_ids_len = input_ids.shape[-1]
            input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
            attention_mask = attention_mask.unsqueeze(1).expand(
                batch_size, effective_batch_mult * num_beams, input_ids_len
            )

            input_ids = input_ids.contiguous().view(
                effective_batch_size * num_beams, input_ids_len
            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
            attention_mask = attention_mask.contiguous().view(
                effective_batch_size * num_beams, input_ids_len
            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)

        encoder_input_ids = input_ids

        if self.config.is_encoder_decoder:
            # create empty decoder_input_ids
            input_ids = torch.full(
                (effective_batch_size * num_beams, 1),
                decoder_start_token_id,
                dtype=torch.long,
                device=next(self.parameters()).device,
            )
            cur_len = 1

            assert (
                batch_size == encoder_outputs.last_hidden_state.shape[0]
            ), f"expected encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} "

            # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
            expanded_batch_idxs = (
                torch.arange(batch_size)
                .view(-1, 1)
                .repeat(1, num_beams * effective_batch_mult)
                .view(-1)
                .to(input_ids.device)
            )

            # expand encoder_outputs
            encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
                0, expanded_batch_idxs
            )

            # save encoder_outputs in `model_kwargs`
            model_kwargs["encoder_outputs"] = encoder_outputs
            model_kwargs["input_embeds"] = input_embeds
            model_kwargs["encoder_input_ids"] = encoder_input_ids

        else:
            cur_len = input_ids.shape[-1]
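        # For reference (illustrative numbers): the expansion above with
        # batch_size = 2, num_beams = 3 and effective_batch_mult = 1 gives
        # expanded_batch_idxs = tensor([0, 0, 0, 1, 1, 1]), so every beam of an
        # example is paired with that example's encoder states via index_select.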
        assert (
            cur_len < max_length
        ), f"The context has {cur_len} tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...)` or `config.max_length = ...`"

        output = self._generate_no_beam_search(
            input_ids,
            cur_len=cur_len,
            max_length=max_length,
            min_length=min_length,
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            bad_words_ids=bad_words_ids,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            batch_size=effective_batch_size,
            attention_mask=attention_mask,
            use_cache=use_cache,
            model_kwargs=model_kwargs,
        )

        return output

    def _generate_no_beam_search(
        self,
        input_ids,
        cur_len,
        max_length,
        min_length,
        do_sample,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
        no_repeat_ngram_size,
        bad_words_ids,
        pad_token_id,
        eos_token_id,
        batch_size,
        attention_mask,
        use_cache,
        model_kwargs,
    ):
        """Generate sequences for each example without beam search (num_beams == 1).
        All returned sequences are generated independently.
        """
        # length of generated sentences / unfinished sentences
        unfinished_sents = input_ids.new(batch_size).fill_(1)
        sent_lengths = input_ids.new(batch_size).fill_(max_length)

        past = None
        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(
                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
            )

            outputs = self(**model_inputs, return_dict=True)  # calling forward here
            # outputs.logits: (batch, seq_len, vocab_size)
            next_token_logits = outputs.logits[:, -1, :]

            scores = self.postprocess_next_token_scores(
                scores=next_token_logits,
                input_ids=input_ids,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids,
                cur_len=cur_len,
                min_length=min_length,
                max_length=max_length,
                eos_token_id=eos_token_id,
                repetition_penalty=repetition_penalty,
                batch_size=batch_size,
                num_beams=1,
            )

            # if model has past, then set the past variable to speed up decoding
            if "past_key_values" in outputs:
                past = outputs.past_key_values
            elif "mems" in outputs:
                past = outputs.mems

            if do_sample:
                # Temperature (higher temperature => more likely to sample low probability tokens)
                if temperature != 1.0:
                    scores = scores / temperature
                # Top-p/top-k filtering
                next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
                # Sample
                probs = F.softmax(next_token_logscores, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                # Greedy decoding
                next_token = torch.argmax(next_token_logits, dim=-1)

            # update generations and finished sentences
            if eos_token_id is not None:
                # pad finished sentences if eos_token_id exists
                tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
            else:
                tokens_to_add = next_token

            # add token and increase length by one
            input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
            cur_len = cur_len + 1

            if eos_token_id is not None:
                eos_in_sents = tokens_to_add == eos_token_id
                # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
                is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
                sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
                # unfinished_sents is set to zero if eos in sentence
                unfinished_sents.mul_((~eos_in_sents).long())

            # stop when there is an EOS token in each sentence, or if we exceed the maximum length
            if unfinished_sents.max() == 0:
                break
            # extend attention_mask for new generated input if only decoder
            if self.config.is_encoder_decoder is False:
                attention_mask = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )

        return input_ids

    # Function from `generation_utils.py` of the Transformers library
    def postprocess_next_token_scores(
        self,
        scores,
        input_ids,
        no_repeat_ngram_size,
        bad_words_ids,
        cur_len,
        min_length,
        max_length,
        eos_token_id,
        repetition_penalty,
        batch_size,
        num_beams,
    ):
        # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
        if repetition_penalty != 1.0:
            self.enforce_repetition_penalty_(
                scores,
                batch_size,
                num_beams,
                input_ids,
                repetition_penalty,
            )

        # set eos token prob to zero if min_length is not reached
        if eos_token_id is not None and cur_len < min_length:
            scores[:, eos_token_id] = -float("inf")

        return scores
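

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only, not part of the original module).
    # It assumes the transformers 3.x API this file was written against, network
    # access to download facebook/bart-large, and a BartModel whose forward()
    # accepts the extra keyword arguments (e.g. input_embeds, past_key_values)
    # passed by BartConstrainedGen.forward above. The example sentences are made up.
    from transformers import BartTokenizer, BartConfig

    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
    config = BartConfig.from_pretrained('facebook/bart-large')
    model = BartConstrainedGen(config, tokenizer)
    model.resize_token_embeddings()  # sets model.vocab_size, used by the pointer head

    source = tokenizer("The company announced the acquisition on Monday.", return_tensors='pt')
    target = tokenizer("announced the acquisition", return_tensors='pt')

    # training-style call (task=0): returns (masked_lm_loss, lm_logits, ...)
    outputs = model(
        input_ids=source['input_ids'],
        attention_mask=source['attention_mask'],
        decoder_input_ids=target['input_ids'],
        decoder_attention_mask=target['attention_mask'],
        task=0,
    )
    print('loss:', outputs[0].item())

    # greedy constrained generation over the source tokens; this generate() only
    # implements the no-beam-search path, so keep num_beams=1
    generated = model.generate(
        source['input_ids'],
        attention_mask=source['attention_mask'],
        max_length=20,
        num_beams=1,
        do_sample=False,
    )
    print('generated:', tokenizer.decode(generated[0], skip_special_tokens=True))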