Stopping Criteria Not Working

#15
by asmittal - opened

Has anyone tried using stopping criteria with Mistral 0.2? It's not stopping generation on the token provided in the stopping criteria.
I'm using this code:
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the first sequence ends with any given stop sequence.

    Args:
        stops: Iterable of token-id tensors. 1-D tensors are used as-is. Higher-rank
            tensors, such as the (1, k) tensors returned by
            ``tokenizer(text, return_tensors='pt')['input_ids']``, are split into
            rows, and each row becomes one stop sequence. Passing a 2-D tensor
            directly also works, because iterating it yields its rows.
            NOTE: by default the tokenizer may prepend special tokens (e.g. BOS),
            and some tokenizers also add a prefix-space token. Such ids never
            appear at the end of generated text, so the criterion would never
            fire. Inspect the ids (e.g. ``tokenizer.convert_ids_to_tokens``) first.
        encounters: Accepted for backward compatibility; not used.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        self.stops = []
        for stop in (stops if stops is not None else ()):
            if stop.dim() > 1:
                # Collapse all leading dims so every row is one stop sequence.
                self.stops.extend(stop.flatten(0, -2))
            else:
                self.stops.append(stop)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if the first sequence in the batch ends with a stop sequence."""
        generated = input_ids[0]  # only batch row 0 is inspected
        for i, stop in enumerate(self.stops):
            n = stop.shape[-1]
            # Skip empty stops: a [-0:] slice would cover the whole sequence.
            # Skip stops longer than the sequence: they cannot match yet.
            if n == 0 or n > generated.shape[-1]:
                continue
            if stop.device != generated.device:
                # Move to the model's device once instead of assuming "cuda".
                stop = stop.to(generated.device)
                self.stops[i] = stop
            if torch.all(generated[-n:] == stop).item():
                return True
        return False

stop_word = "\n"
# add_special_tokens=False: with the default, the tokenizer may prepend special
# tokens (e.g. BOS). Those never appear at the end of the generated text, so the
# stop sequence could never match.
# NOTE(review): SentencePiece-style tokenizers may still prepend a prefix-space
# token. Check tokenizer.convert_ids_to_tokens(stop_word_ids[0]) and drop that id
# if present, so the stop ids match what the model actually generates.
stop_word_ids = tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids']
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_word_ids)])

It should stop generating at "\n", but it doesn't. Any thoughts?
output = model.generate(
    input_ids=model_inputs['input_ids'],
    attention_mask=model_inputs['attention_mask'],
    max_new_tokens=max_tokens,
    num_beams=1,
    # early_stopping only affects beam search (num_beams > 1). With num_beams=1
    # it was a no-op, so it is omitted here.
    # NOTE(review): temperature only takes effect when sampling is enabled
    # (do_sample=True here or in the model's generation config). Otherwise
    # decoding is greedy and temperature is ignored.
    temperature=temperature,
    stopping_criteria=stopping_criteria,
)

Make sure you check which token is the one that is generated, and which token you are using as a stop token. The \n token given by tokenizer(stop_word, return_tensors='pt')['input_ids'] is not the same as tokenizer.convert_tokens_to_ids(stop_word), because the tokenizer always adds a prefix space.

@ArthurZ But this same code works with the Falcon-7B-Instruct model.

tokenizer.convert_tokens_to_ids(stop_word) just returns 0 for every token.

stop_list = [" \n\nAnswer:", " \n", " \n\n"]
# The tokenizer already returns a LongTensor of shape (1, k). Move it to the
# target device directly instead of re-wrapping it with the legacy
# torch.LongTensor constructor. add_special_tokens=False keeps BOS/EOS out of
# the stop ids.
stop_token_ids = [
    tokenizer(x, return_tensors='pt', add_special_tokens=False)['input_ids'].to(device)
    for x in stop_list
]
    
class StopOnTokens(StoppingCriteria):
    """Stop generation when the first sequence ends with any stop sequence.

    Reads the module-level ``stop_token_ids`` (a list of (1, k) id tensors) at
    call time. The first id of each stop sequence is skipped. Presumably that id
    is the prefix-space token the tokenizer emits for the leading space of each
    entry, which the model does not produce in that form (verify per tokenizer).
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        generated = input_ids[0]  # only batch row 0 is inspected
        for stop_ids in stop_token_ids:
            tail = stop_ids[0][1:]
            n = tail.shape[-1]
            # A one-token stop leaves an empty tail. The old slice [-0:] then
            # covered the whole sequence, which either raised a shape error or
            # (for a length-1 sequence) returned True immediately. Tails longer
            # than the sequence cannot match yet.
            if n == 0 or n > generated.shape[-1]:
                continue
            if torch.eq(generated[-n:], tail).all():
                return True
        return False

# Wrap the custom criterion so it can be passed to generate(stopping_criteria=...).
stopping_criteria = StoppingCriteriaList(
    [StopOnTokens()],
)

Sign up or log in to comment