# h2ogpt-chatbot / stopping.py
# Duplicated by trhacknon from h2oai/h2ogpt-chatbot (revision d7185d6).
import traceback
from queue import Queue
from threading import Thread
import collections.abc
import torch
from transformers import StoppingCriteria
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once any stop-token sequence has been seen often enough.

    On every call the generated ids of the first sequence in the batch are
    checked against each stop sequence.  Each time the ids currently end with a
    stop sequence, that sequence's counter is incremented; generation stops as
    soon as a counter reaches its required number of encounters.

    Args:
        stops: iterable of 1-D token-id tensors, each one stop sequence.
        encounters: required match counts.  ``stops[i]`` uses
            ``encounters[i % len(encounters)]``, so ``len(stops)`` must be a
            multiple of ``len(encounters)``.
        device: device the stop tensors are moved to; it should match the
            device of the ``input_ids`` seen during generation.  Defaults to
            ``"cuda"`` for backward compatibility.
    """

    def __init__(self, stops=None, encounters=None, device="cuda"):
        super().__init__()
        # Copy into fresh lists: avoids shared mutable defaults and aliasing
        # the caller's lists.
        stops = [] if stops is None else list(stops)
        encounters = [] if encounters is None else list(encounters)
        # Guard the modulo: with no encounters only an empty stop list is
        # valid (previously this raised ZeroDivisionError).
        assert not stops or (
            len(encounters) > 0 and len(stops) % len(encounters) == 0
        ), "Number of stops and encounters must match"
        self.encounters = encounters
        self.stops = [stop.to(device) for stop in stops]
        self.num_stops = [0] * len(stops)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when some stop sequence has reached its encounter count."""
        generated = input_ids[0]
        for stopi, stop in enumerate(self.stops):
            stop_len = len(stop)
            # An empty stop would be compared with the whole sequence ([-0:]),
            # and a stop longer than the sequence either fails to broadcast or
            # matches spuriously; neither can be a real match, so skip them.
            if stop_len == 0 or stop_len > len(generated):
                continue
            if torch.all(stop == generated[-stop_len:]).item():
                self.num_stops[stopi] += 1
                if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
                    return True
        return False
class Stream(StoppingCriteria):
    """Pseudo stopping criterion used only to observe generation progress.

    It never requests a stop.  On every generation step it forwards the token
    ids of the first sequence in the batch to ``func``.  For decoder-only
    transformers those ids include the original prompt tokens.

    Args:
        func (`callable`, optional):
            Called once per generation step with ``input_ids[0]``.  When
            ``None``, the criterion does nothing.
    """

    def __init__(self, func=None):
        self.func = func

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        callback = self.func
        if callback is None:
            return False
        # Only the first of possibly several batched responses is streamed.
        callback(input_ids[0])
        return False
class CallbackToGenerator(collections.abc.Generator):
    """Adapt a callback-driven function into a generator.

    ``func`` runs in a background thread as
    ``func(*args, callback=<cb>, **kwargs)``.  Each value ``func`` passes to
    the callback is yielded to the consumer.  The value the consumer passes to
    ``send`` (``None`` for ``next``) becomes the callback's return value.
    ``throw`` raises the given exception inside ``func`` at the pending
    callback.  When ``func`` returns, the generator raises ``StopIteration``
    carrying the return value (unless it is ``None``).

    Note: this starts a background thread in ``__init__`` (not marked daemon).
    Exhaust or ``close()`` the generator so that thread can finish.

    Args:
        func: callable accepting a ``callback`` keyword argument.
        *args: positional arguments forwarded to ``func``.
        callback: stored as ``self.callback``; not used by this class.
        **kwargs: keyword arguments forwarded to ``func``.
    """

    def __init__(self, func, *args, callback=None, **kwargs):
        self.func = func
        self.args = args
        self.kwargs = kwargs
        self.callback = callback
        # Handshake queues: the consumer puts (cmd, value) on the ready queue;
        # the worker answers on the done queue with (is_exception, value).
        self._ready_queue = Queue(1)
        self._done_queue = Queue(1)
        # Flipped to True by the worker once func has finished in any way.
        self._done_holder = [False]

        # Bind everything the worker needs to locals so its closures never
        # reference ``self``.  Otherwise the running thread would keep this
        # object alive and ``__del__`` could never close it.
        ready_queue = self._ready_queue
        done_queue = self._done_queue
        done_holder = self._done_holder

        def val_callback(value):
            # Hand the value to the consumer, then block until it resumes us.
            done_queue.put((False, value))
            cmd, val = ready_queue.get()
            if cmd == 'send':
                return val
            elif cmd == 'throw':
                raise val
            else:
                assert False  # pragma: no cover

        def thread_func():
            # Like a native generator, the first resume must be send(None).
            while True:
                cmd, val = ready_queue.get()
                if cmd == 'send' and val is not None:
                    done_queue.put((True, TypeError("can't send non-None value to a just-started generator")))
                    continue
                break
            try:
                if cmd == 'throw':
                    raise val
                ret = func(*args, callback=val_callback, **kwargs)
                raise StopIteration(ret) if ret is not None else StopIteration
            except BaseException as e:
                # Normal completion (StopIteration) and errors both end the
                # generator; the exception is re-raised in the consumer.
                done_holder[0] = True
                done_queue.put((True, e))

        thread = Thread(target=thread_func)
        thread.start()
        # Assigned only after a successful start, so __del__ never waits on a
        # worker that does not exist.
        self._thread = thread

    def _put(self, *args):
        """Send ``(cmd, value)`` to the worker; return or raise its reply."""
        if self._done_holder[0]:
            raise StopIteration
        self._ready_queue.put(args)
        is_exception, val = self._done_queue.get()
        if is_exception:
            try:
                raise val
            finally:
                # prevent val's traceback containing a reference cycle
                del val
        else:
            return val

    def send(self, value):
        """Resume ``func`` with ``value``; return the next callback value."""
        return self._put('send', value)

    def throw(self, exc):
        """Raise ``exc`` inside ``func``; return the next callback value."""
        return self._put('throw', exc)

    def close(self):
        """Raise GeneratorExit inside ``func`` and join the worker thread."""
        try:
            self.throw(GeneratorExit)
        except StopIteration:
            self._thread.join()
        except GeneratorExit:
            self._thread.join()
        except BaseException:
            self._thread.join()
            raise
        else:
            # yielded again, can't clean up the thread
            raise RuntimeError('Task with callback ignored GeneratorExit')

    def __del__(self):
        # Skip cleanup if __init__ failed before the worker was started;
        # close() would otherwise fail or block waiting for a reply.
        if hasattr(self, '_thread'):
            self.close()