nizar-sayad committed on
Commit
84e98e5
1 Parent(s): 9b01ad2

Delete handler.py

Files changed (1)
  1. handler.py +0 -52
handler.py DELETED
@@ -1,52 +0,0 @@
-from typing import Dict, Any
-from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
-import torch
-
-
-class EndpointHandler():
-    def __init__(self, path="."):
-        # load model and tokenizer from path
-        self.model = AutoModelForCausalLM.from_pretrained(path)
-        self.tokenizer = AutoTokenizer.from_pretrained(path)
-
-    # Stopping criteria that ends generation once a keyword token has appeared `occurrences` times
-    class KeywordsStoppingCriteria(StoppingCriteria):
-        def __init__(self, keywords_ids: list, occurrences: int):
-            super().__init__()
-            self.keywords = keywords_ids
-            self.occurrences = occurrences
-            self.count = 0
-
-        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-            if input_ids[0][-1] in self.keywords:
-                self.count += 1
-                if self.count == self.occurrences:
-                    return True
-            return False
-
-    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        # process input
-        inputs = data["inputs"]
-
-        stop_words = ['.']
-        # index 1 skips the special token some tokenizers prepend to encoded text
-        stop_ids = [self.tokenizer.encode(w)[1] for w in stop_words]
-        gen_outputs = []
-        gen_outputs_no_input = []
-        gen_input = self.tokenizer(inputs, return_tensors="pt")
-        # sample five candidate continuations, each stopped after two '.' tokens
-        for _ in range(5):
-            stop_criteria = self.KeywordsStoppingCriteria(stop_ids, occurrences=2)
-            gen_output = self.model.generate(gen_input.input_ids, do_sample=True,
-                                             top_k=10,
-                                             top_p=0.95,
-                                             max_new_tokens=100,
-                                             penalty_alpha=0.6,
-                                             stopping_criteria=StoppingCriteriaList([stop_criteria]))
-            gen_outputs.append(gen_output)
-            gen_outputs_no_input.append(gen_output[0][len(gen_input.input_ids[0]):])
-
-        gen_outputs_decoded = [self.tokenizer.decode(gen_output[0], skip_special_tokens=True) for gen_output in gen_outputs]
-        gen_outputs_no_input_decoded = [self.tokenizer.decode(gen_output_no_input, skip_special_tokens=True) for gen_output_no_input in gen_outputs_no_input]
-
-        return {"gen_outputs_decoded": gen_outputs_decoded, "gen_outputs_no_input_decoded": gen_outputs_no_input_decoded}