nizar-sayad committed
Commit 84e98e5
1 Parent(s): 9b01ad2
Delete handler.py
handler.py +0 -52
handler.py
DELETED
@@ -1,52 +0,0 @@
-from typing import Dict, List, Any
-from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList,StoppingCriteria
-import torch
-
-
-class EndpointHandler():
-    def __init__(self, path="."):
-        # load model and processor from path
-        self.model = AutoModelForCausalLM.from_pretrained(path)
-        self.tokenizer = AutoTokenizer.from_pretrained(path)
-
-    # Create a stopping criteria class
-    class KeywordsStoppingCriteria(StoppingCriteria):
-        def __init__(self, keywords_ids: list, occurrences: int):
-            super().__init__()
-            self.keywords = keywords_ids
-            self.occurrences = occurrences
-            self.count = 0
-
-        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-            if input_ids[0][-1] in self.keywords:
-                self.count += 1
-                if self.count == self.occurrences:
-                    return True
-            return False
-
-    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-
-        # process input
-        inputs = data["inputs"]
-
-        stop_words = ['.']
-        stop_ids = [self.tokenizer.encode(w)[1] for w in stop_words]
-        gen_outputs = []
-        gen_outputs_no_input = []
-        gen_input = self.tokenizer(input, return_tensors="pt")
-        for _ in range(5):
-            stop_criteria = KeywordsStoppingCriteria(stop_ids, occurrences=2)
-            gen_output = self.model.generate(gen_input.input_ids, do_sample=True,
-                                             top_k=10,
-                                             top_p=0.95,
-                                             max_new_tokens=100,
-                                             penalty_alpha=0.6,
-                                             stopping_criteria=StoppingCriteriaList([stop_criteria])
-                                             )
-            gen_outputs.append(gen_output)
-            gen_outputs_no_input.append(gen_output[0][len(gen_input.input_ids[0]):])
-
-        gen_outputs_decoded = [self.tokenizer.decode(gen_output[0], skip_special_tokens=True) for gen_output in gen_outputs]
-        gen_outputs_no_input_decoded = [self.tokenizer.decode(gen_output_no_input, skip_special_tokens=True) for gen_output_no_input in gen_outputs_no_input]
-
-        return {"gen_outputs_decoded": gen_outputs_decoded, "gen_outputs_no_input_decoded": gen_outputs_no_input_decoded}
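The removed file follows the custom-handler convention for Hugging Face Inference Endpoints: the service imports handler.py, builds EndpointHandler(path) once at startup, and then calls the instance with a {"inputs": ...} dictionary for each request. As committed, the handler would not run: __call__ tokenizes the Python built-in input instead of the inputs value it just read from the payload, and it refers to KeywordsStoppingCriteria by its bare name even though that class is nested inside EndpointHandler. A minimal corrected sketch of the same logic is below, with the stopping-criteria class hoisted to module level and the generation parameters carried over unchanged from the deleted code; the path and prompt in the trailing smoke test are placeholders, not values from the commit.

# Corrected sketch of the deleted handler; not the code as committed.
from typing import Any, Dict, List

import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          StoppingCriteria, StoppingCriteriaList)


class KeywordsStoppingCriteria(StoppingCriteria):
    # Stop generation once any keyword token id has appeared `occurrences` times.
    def __init__(self, keywords_ids: List[int], occurrences: int):
        super().__init__()
        self.keywords = keywords_ids
        self.occurrences = occurrences
        self.count = 0

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if input_ids[0][-1] in self.keywords:
            self.count += 1
            if self.count == self.occurrences:
                return True
        return False


class EndpointHandler:
    def __init__(self, path: str = "."):
        # Load the model and tokenizer from the repository root.
        self.model = AutoModelForCausalLM.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        inputs = data["inputs"]  # the request text, not the built-in input()

        stop_words = ["."]
        # Index 1 mirrors the original code and assumes the tokenizer prepends a special token.
        stop_ids = [self.tokenizer.encode(w)[1] for w in stop_words]
        gen_input = self.tokenizer(inputs, return_tensors="pt")

        gen_outputs, gen_outputs_no_input = [], []
        for _ in range(5):  # draw five samples, each stopping after two '.' tokens
            stop_criteria = KeywordsStoppingCriteria(stop_ids, occurrences=2)
            gen_output = self.model.generate(
                gen_input.input_ids,
                do_sample=True,
                top_k=10,
                top_p=0.95,
                max_new_tokens=100,
                penalty_alpha=0.6,
                stopping_criteria=StoppingCriteriaList([stop_criteria]),
            )
            gen_outputs.append(gen_output)
            # Also keep the continuation with the prompt tokens stripped off.
            gen_outputs_no_input.append(gen_output[0][len(gen_input.input_ids[0]):])

        return {
            "gen_outputs_decoded": [
                self.tokenizer.decode(o[0], skip_special_tokens=True) for o in gen_outputs
            ],
            "gen_outputs_no_input_decoded": [
                self.tokenizer.decode(o, skip_special_tokens=True) for o in gen_outputs_no_input
            ],
        }


# Hypothetical local smoke test; the path and prompt are placeholders.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    print(handler({"inputs": "Once upon a time"}))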