ClaudiaIoana550 committed on
Commit
614db54
1 Parent(s): 2cf1f8d

Create handler.py

Files changed (1)
  1. handler.py +105 -0
handler.py ADDED
@@ -0,0 +1,105 @@
from typing import Any, Dict

import torch
import transformers
from langchain.llms import HuggingFacePipeline
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

# Use bfloat16 on Ampere-or-newer GPUs (compute capability >= 8), otherwise
# float16. Checking torch.cuda.is_available() first avoids a crash when the
# module is imported on a CPU-only machine.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    dtype = torch.bfloat16
else:
    dtype = torch.float16
class StopGenerationCriteria(StoppingCriteria):
    """Stops generation when the model starts repeating itself."""

    def __init__(self, max_duplicate_sequences=3, max_repeated_words=2):
        super().__init__()
        # Load the tokenizer once here; loading it inside __call__ (which runs
        # on every generation step) would be prohibitively slow.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "ClaudiaIoana550/try1_deploy_falcon", trust_remote_code=True
        )
        self.generated_sequences = set()
        self.max_duplicate_sequences = max_duplicate_sequences
        self.max_repeated_words = max_repeated_words

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        tokens = input_ids[0].tolist()

        # 1) Stop if the trailing 30-token window already occurred earlier in
        #    the output, i.e. the model is looping over a long phrase.
        if len(tokens) >= 50:
            window = tokens[-30:]
            occurrences = sum(
                1
                for i in range(len(tokens) - len(window) + 1)
                if tokens[i : i + len(window)] == window
            )
            if occurrences >= 2:
                return True

        # 2) Stop if the same token is emitted more than max_repeated_words
        #    times in a row.
        generated_tokens = [self.tokenizer.decode(token_id) for token_id in input_ids[0]]
        count = 1
        prev_token = None
        for token in generated_tokens:
            if token == prev_token:
                count += 1
                if count > self.max_repeated_words:
                    return True
            else:
                count = 1
                prev_token = token

        # 3) Note: self.generated_sequences is never populated inside this
        #    class, so this check only fires if duplicate sequences are
        #    recorded by external code.
        if len(self.generated_sequences) >= self.max_duplicate_sequences:
            return True

        return False

# Configure the stopping thresholds: at most 1 duplicate sequence and at most
# 2 consecutive repeats of the same token.
max_duplicate_sequences = 1
max_repeated_words = 2

# Create an instance of StopGenerationCriteria and wrap it in the
# StoppingCriteriaList expected by the generation pipeline.
stop_criteria = StopGenerationCriteria(max_duplicate_sequences, max_repeated_words)
stopping_criteria = StoppingCriteriaList([stop_criteria])
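
# For reference, a StoppingCriteriaList like the one above can also be passed
# directly to `model.generate` rather than through a pipeline. A hypothetical
# sketch (assumes `model` and `tokenizer` are already loaded; not part of the
# original commit):
#
#     inputs = tokenizer("Hello", return_tensors="pt")
#     output_ids = model.generate(**inputs, stopping_criteria=stopping_criteria, max_new_tokens=50)
#     print(tokenizer.decode(output_ids[0]))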

class EndpointHandler:
    def __init__(self, model_path=""):
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            return_dict=True,
            device_map="auto",
            torch_dtype=dtype,
            trust_remote_code=True,
        )

        generation_config = model.generation_config
        generation_config.max_new_tokens = 1700
        generation_config.min_length = 20
        # Note: temperature and top_p only take effect when sampling is
        # enabled on the underlying generation call (do_sample=True).
        generation_config.temperature = 1
        generation_config.top_p = 0.7
        generation_config.num_return_sequences = 1
        generation_config.pad_token_id = tokenizer.eos_token_id
        generation_config.eos_token_id = tokenizer.eos_token_id
        generation_config.repetition_penalty = 1.1

        gpipeline = transformers.pipeline(
            model=model,
            tokenizer=tokenizer,
            return_full_text=True,
            task="text-generation",
            stopping_criteria=stopping_criteria,
            generation_config=generation_config,
        )

        self.llm = HuggingFacePipeline(pipeline=gpipeline)

    def __call__(self, data: Dict[str, Any]) -> Any:
        # The endpoint receives a JSON payload; the prompt is expected under
        # the "inputs" key. LangChain's HuggingFacePipeline returns the
        # generated text as a string, so the result is returned as-is.
        prompt = data.pop("inputs", data)
        result = self.llm(prompt)
        return result
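
A quick way to smoke-test the handler before deploying (a minimal sketch; the dict payload mirrors how Hugging Face Inference Endpoints invoke a custom handler, and the model path and prompt here are illustrative assumptions):

    # Hypothetical local test, not part of the commit.
    handler = EndpointHandler(model_path="ClaudiaIoana550/try1_deploy_falcon")
    output = handler({"inputs": "Write a short story about a robot."})
    print(output)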