import logging
from typing import List

import torch
from transformers import (
    LogitsProcessor,
)


class StopAfterTokenIsGenerated(LogitsProcessor):
    """Force EOS as soon as a sequence ends with one of the stop sequences.

    On every call except the first one after construction or ``reset()``,
    each row of ``input_ids`` whose trailing tokens equal one of ``stops``
    has its scores overwritten so that ``eos_token_id`` is the only token
    with a finite score.

    NOTE(review): the first call is passed through untouched, presumably so
    that a stop sequence already present at the end of the prompt does not
    end generation immediately -- confirm against the caller. Because of this
    flag the processor is stateful: call ``reset()`` before reusing it.
    """

    def __init__(self, stops: List[torch.Tensor], eos_token_id: int):
        """Store the stop sequences and the token id to force on a match.

        Args:
            stops: Token-id tensors, one per stop sequence. Each is flattened
                to 1-D when compared; empty tensors never match.
            eos_token_id: Vocabulary index that is left selectable once a
                stop sequence has been generated.
        """
        super().__init__()
        self.stops = stops
        self.eos_token_id = eos_token_id
        logging.info(f"Stopping criteria words ids: {self.stops}")
        self.first_batch = True

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Mask the scores of every sequence that just produced a stop.

        Args:
            input_ids: Token ids of shape ``(batch_size, sequence_length)``.
            scores: Prediction scores of shape ``(batch_size, vocab_size)``.
                Modified in place.

        Returns:
            The same ``scores`` tensor. Rows whose ``input_ids`` end with a
            stop sequence are ``-inf`` everywhere except ``eos_token_id``,
            which is set to ``0``; all other rows are unchanged.
        """
        if self.first_batch:
            # First call after construction/reset: pass through untouched.
            self.first_batch = False
            return scores

        batch_size, seq_len = input_ids.shape[0], input_ids.shape[-1]
        finished = torch.zeros(
            batch_size, dtype=torch.bool, device=input_ids.device
        )

        for stop in self.stops:
            # Flatten so 0-d (single squeezed token) stops are accepted too.
            stop = stop.reshape(-1)
            stop_len = stop.numel()
            # An empty stop would match vacuously; a stop longer than the
            # sequence cannot match at all.
            if stop_len == 0 or stop_len > seq_len:
                continue
            # Convert once per stop, then compare against every row at once.
            stop = stop.to(device=input_ids.device, dtype=input_ids.dtype)
            finished |= (input_ids[:, -stop_len:] == stop).all(dim=-1)

        finished = finished.to(scores.device)
        # -inf everywhere and 0 on EOS leaves EOS as the only finite score.
        # masked_fill_ keeps the update in place and avoids a host sync.
        scores.masked_fill_(finished.unsqueeze(-1), -float("inf"))
        scores[:, self.eos_token_id].masked_fill_(finished, 0)
        return scores

    def reset(self):
        """Re-arm the first-call pass-through for a new generation run."""
        self.first_batch = True