aaabiao commited on
Commit
25468b9
1 Parent(s): 8f70d80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -0
app.py CHANGED
@@ -13,6 +13,18 @@ from transformers import (
13
  TextIteratorStreamer,
14
  )
15
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  MAX_MAX_NEW_TOKENS = 2048
17
  DEFAULT_MAX_NEW_TOKENS = 1024
18
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
13
  TextIteratorStreamer,
14
  )
15
 
16
+ class StoppingCriteriaSub(StoppingCriteria):
17
+ def __init__(self, stops = [], encounters=1):
18
+ super().__init__()
19
+ self.stops = [stop.to("cuda") for stop in stops]
20
+
21
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
22
+ last_token = input_ids[0][-1]
23
+ for stop in self.stops:
24
+ if tokenizer.decode(stop) == tokenizer.decode(last_token):
25
+ return True
26
+ return False
27
+
28
  MAX_MAX_NEW_TOKENS = 2048
29
  DEFAULT_MAX_NEW_TOKENS = 1024
30
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))