oleksandrfluxon commited on
Commit
42198c1
1 Parent(s): ed817dc

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +8 -8
pipeline.py CHANGED
@@ -6,14 +6,6 @@ from accelerate.utils import get_balanced_memory
6
  from transformers import BitsAndBytesConfig, StoppingCriteria, StoppingCriteriaList
7
  from typing import Dict, List, Any
8
 
9
- # define custom stopping criteria object
10
- class StopOnTokens(StoppingCriteria):
11
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
12
- for stop_ids in stop_token_ids:
13
- if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
14
- return True
15
- return False
16
-
17
  class PreTrainedPipeline():
18
  def __init__(self, path=""):
19
  path = "oleksandrfluxon/mpt-7b-instruct-evaluate"
@@ -44,6 +36,14 @@ class PreTrainedPipeline():
44
  stop_token_ids = [torch.LongTensor(x).to(device) for x in stop_token_ids]
45
  print("===> stop_token_ids", stop_token_ids)
46
 
 
 
 
 
 
 
 
 
47
  stopping_criteria = StoppingCriteriaList([StopOnTokens()])
48
 
49
  self.pipeline = transformers.pipeline(
 
6
  from transformers import BitsAndBytesConfig, StoppingCriteria, StoppingCriteriaList
7
  from typing import Dict, List, Any
8
 
 
 
 
 
 
 
 
 
9
  class PreTrainedPipeline():
10
  def __init__(self, path=""):
11
  path = "oleksandrfluxon/mpt-7b-instruct-evaluate"
 
36
  stop_token_ids = [torch.LongTensor(x).to(device) for x in stop_token_ids]
37
  print("===> stop_token_ids", stop_token_ids)
38
 
39
+ # define custom stopping criteria object
40
+ class StopOnTokens(StoppingCriteria):
41
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
42
+ for stop_ids in stop_token_ids:
43
+ if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
44
+ return True
45
+ return False
46
+
47
  stopping_criteria = StoppingCriteriaList([StopOnTokens()])
48
 
49
  self.pipeline = transformers.pipeline(