import traceback
from queue import Queue
from threading import Thread
import collections.abc

import torch
from transformers import StoppingCriteria


class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once a stop-token sequence has been seen often enough.

    Each entry of ``stops`` is a 1-D tensor of token ids.  Every time
    ``__call__`` finds that the first sequence of ``input_ids`` currently
    ends with a stop sequence, that stop's counter is incremented; the
    criterion fires when a counter reaches its required encounter count.

    Counters are kept on the instance and are never reset by this class, so
    an instance carries its counts across calls.

    Args:
        stops: Iterable of 1-D token-id tensors to watch for.  ``None`` (the
            default) or an empty iterable means "never stop".
        encounters: Required number of matches per stop.  It is cycled over
            ``stops`` (stop ``i`` uses ``encounters[i % len(encounters)]``),
            so ``len(stops)`` must be a multiple of ``len(encounters)``.
        device: Device the stop tensors are moved to; it has to match the
            device of the ``input_ids`` passed to ``__call__``.

    Raises:
        AssertionError: If ``stops`` is non-empty and ``encounters`` is empty,
            or ``len(stops)`` is not a multiple of ``len(encounters)``.
    """

    def __init__(self, stops=None, encounters=None, device="cuda"):
        super().__init__()
        # ``None`` defaults avoid shared mutable default arguments; copying
        # also decouples the instance from later mutation by the caller.
        stops = [] if stops is None else list(stops)
        encounters = [] if encounters is None else list(encounters)

        # Raised explicitly (instead of via ``assert``) so the check survives
        # ``python -O``; AssertionError is kept so callers see the same type.
        # An empty ``encounters`` is only acceptable when there is nothing to
        # stop on -- previously that case crashed with ZeroDivisionError.
        if stops and (not encounters or len(stops) % len(encounters) != 0):
            raise AssertionError("Number of stops and encounters must match")

        self.encounters = encounters
        self.stops = [stop.to(device) for stop in stops]
        self.num_stops = [0] * len(stops)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when any stop sequence has reached its encounter count.

        Only ``input_ids[0]`` (the first sequence) is inspected; ``scores``
        and ``kwargs`` are accepted for interface compatibility and unused.
        """
        sequence = input_ids[0]
        seq_len = sequence.shape[0]
        for stopi, stop in enumerate(self.stops):
            stop_len = len(stop)
            # An empty stop can never be meaningfully matched, and a stop
            # longer than the sequence would make the tail slice shorter than
            # ``stop``: the comparison would then either raise a shape error
            # or broadcast a single token against the whole stop sequence.
            if stop_len == 0 or stop_len > seq_len:
                continue
            if torch.all(stop == sequence[-stop_len:]).item():
                self.num_stops[stopi] += 1
                required = self.encounters[stopi % len(self.encounters)]
                if self.num_stops[stopi] >= required:
                    return True
        return False