# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

import torch


class StoppingCriteria:
    """Base class for stopping conditions checked during generation."""

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        raise NotImplementedError("StoppingCriteria needs to be subclassed")


class StoppingCriteriaList(list):
    """A list of criteria; generation stops as soon as any one of them fires."""

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        return any(criteria(input_ids, scores, **kwargs) for criteria in self)


class MaxLengthCriteria(StoppingCriteria):
    """Stops once the generated sequence reaches `max_length` tokens."""

    def __init__(self, max_length: int):
        self.max_length = max_length

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        cur_len = input_ids.shape[-1]
        return cur_len >= self.max_length


class StopOnEOS(StoppingCriteria):
    """Stops once every sequence in the batch contains the EOS token."""

    def __init__(self, eos_id: int):
        self._eos_id = eos_id

    def __call__(
        self, input_ids: torch.LongTensor, _: torch.FloatTensor, **kwargs
    ) -> bool:
        # input_ids.shape = [batch, seq_len]
        return (input_ids == self._eos_id).sum(dim=1).all()


class StopOnEOSAfterBatchIndex(StoppingCriteria):
    """Stops once every sequence contains an EOS token at or after a
    per-sequence start position, so EOS tokens in the prompt are ignored."""

    def __init__(self, eos_id: int, batch_index: list[int]):
        self._eos_id = eos_id
        # Shape [batch, 1]: the first position at which an EOS counts, per sequence.
        self.batch_index = torch.tensor(batch_index, dtype=torch.long).unsqueeze(1)

    def __call__(
        self, input_ids: torch.LongTensor, _: torch.FloatTensor, **kwargs
    ) -> bool:
        # input_ids.shape = [batch, seq_len]
        eos_mask = input_ids == self._eos_id
        # Mask out positions before each sequence's start index. The arange and
        # the stored indices are placed on input_ids' device so the elementwise
        # `&` below does not fail with a CPU/GPU device mismatch.
        consider_eos_mask = (
            torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0)
            >= self.batch_index.to(input_ids.device)
        )
        valid_eos = eos_mask & consider_eos_mask
        return valid_eos.sum(dim=1).all()
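

# Minimal usage sketch (not part of the original file): wires two criteria into
# a StoppingCriteriaList and checks them against dummy tensors. The token ids,
# shapes, and eos_id below are illustrative assumptions, not values taken from
# a real Chameleon generation loop.
if __name__ == "__main__":
    eos_id = 2  # assumed EOS token id, for illustration only
    criteria = StoppingCriteriaList(
        [MaxLengthCriteria(max_length=8), StopOnEOS(eos_id=eos_id)]
    )

    # Both sequences in the batch already contain the EOS token, so StopOnEOS
    # fires even though max_length has not yet been reached.
    input_ids = torch.tensor([[5, 7, 2], [9, 2, 4]], dtype=torch.long)
    scores = torch.zeros(2, 10)  # placeholder logits; unused by these criteria
    assert criteria(input_ids, scores)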