Stopping Criteria Not Working
Has anyone tried using stopping criteria with Mistral 0.2? It's not stopping generation at the token provided in the stopping criteria.
I'm using this piece of code
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once any of the given stop-token sequences appears
    at the end of the generated ids.

    Args:
        stops: iterable of 1-D LongTensors, each a stop-token-id sequence.
        encounters: kept for interface compatibility (it was accepted but
            unused in the original; now stored for callers that read it).
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # BUG FIX: `stops=[]` was a mutable default argument, shared across
        # all instances created without an explicit value.
        # TODO(review): the hard-coded "cuda" fails on CPU-only hosts —
        # consider `stop.to(input_ids.device)` inside __call__ instead.
        self.stops = [stop.to("cuda") for stop in (stops if stops is not None else [])]
        self.encounters = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # **kwargs added: recent transformers versions pass extra keyword
        # arguments to stopping criteria; without it the call raises TypeError.
        for stop in self.stops:
            # Compare the trailing window of generated ids to the stop sequence.
            if torch.all(stop == input_ids[0][-len(stop):]).item():
                return True
        return False
stop_word = "\n"
# BUG FIX: without add_special_tokens=False the tokenizer prepends special
# tokens (e.g. BOS) to the encoding, so the stored ids can never match the
# suffix of the generated sequence and the criterion never fires.
# [0] drops the batch dimension so the stop sequence is a 1-D id tensor,
# and it is wrapped in a list because StoppingCriteriaSub iterates `stops`
# expecting one 1-D sequence per stop word.
stop_word_ids = tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'][0]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=[stop_word_ids])])
# NOTE(review): SentencePiece tokenizers (Mistral/Llama) also add a prefix
# space, so the ids produced here may still differ from the single "\n" id
# emitted mid-generation — verify against tokenizer.convert_tokens_to_ids.
It should stop generating at "\n", but that's not working. Any thoughts?
# NOTE(review): early_stopping only affects beam search; with num_beams=1 it
# is ignored (recent transformers versions emit a warning for this combination).
# NOTE(review): temperature has no effect unless do_sample=True is also passed —
# greedy decoding ignores it entirely.
output = model.generate(
input_ids=model_inputs['input_ids'], attention_mask=model_inputs['attention_mask'],
max_new_tokens=max_tokens,
num_beams=1,
early_stopping=True,
temperature=temperature,
stopping_criteria=stopping_criteria)
Make sure you check which token is actually being generated, and which token you are using as the stop token. The "\n" token id given by tokenizer(stop_word, return_tensors='pt')['input_ids'] is not the same as tokenizer.convert_tokens_to_ids(stop_word), because the tokenizer always adds a prefix space.
tokenizer.convert_tokens_to_ids(stop_word) just returns 0 for every token I tried (presumably the unknown-token id).
# Candidate stop strings. The leading space matters: SentencePiece tokenizers
# prepend a prefix-space token, which the matcher below discards via [1:].
stop_list = [" \n\nAnswer:", " \n", " \n\n"]
# add_special_tokens=False keeps BOS/EOS out of the stop sequences.
stop_token_ids = [tokenizer(x, return_tensors='pt', add_special_tokens=False)['input_ids'] for x in stop_list]
# Each entry stays 2-D ([1, seq_len]); move every sequence to the model's
# device once, up front, rather than per call.
stop_token_ids = [torch.LongTensor(x).to(device) for x in stop_token_ids]
class StopOnTokens(StoppingCriteria):
    """Stop generation when the tail of the generated ids matches any of the
    module-level stop_token_ids sequences (prefix-space token excluded)."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_ids in stop_token_ids:
            # stop_ids is 2-D ([1, seq_len]); stop_ids[0] is the id sequence.
            # stop_ids[0][1:] drops the tokenizer's prefix-space token, and the
            # matching-length suffix of the generated ids is compared to it.
            # NOTE(review): if a stop string tokenizes to a single id,
            # len(stop_ids[0]) == 1 makes the left slice the WHOLE sequence and
            # the right slice empty, so this comparison misbehaves — confirm
            # every stop string yields at least two ids.
            if torch.eq(input_ids[0][-len(stop_ids[0])+1:], stop_ids[0][1:]).all():
                return True
        return False


# Wrap in the list container that model.generate(stopping_criteria=...) expects.
stopping_criteria = StoppingCriteriaList([StopOnTokens()])