# try2_deploy_falcon / handler.py
from typing import Any, Dict

import torch
import transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

from langchain.llms import HuggingFacePipeline

# Prefer bfloat16 on compute-capability-8.x (Ampere) GPUs, float16 otherwise.
# Note: this assumes a CUDA device is visible at import time.
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
class StopGenerationCriteria(StoppingCriteria):
    """Stops generation when the model starts repeating itself."""

    def __init__(self, max_duplicate_sequences=3, max_repeated_words=2, tokenizer=None):
        # Load the tokenizer once up front (optionally reusing a caller-supplied
        # one) rather than re-downloading it on every generation step.
        self.tokenizer = tokenizer or AutoTokenizer.from_pretrained(
            "ClaudiaIoana550/try1_deploy_falcon", trust_remote_code=True
        )
        self.generated_sequences = set()
        self.max_duplicate_sequences = max_duplicate_sequences
        self.max_repeated_words = max_repeated_words

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        generated_sequence = input_ids[0].tolist()

        # Stop when the last 30 tokens already appeared earlier in the output,
        # i.e. the tail occurs at least twice in the full sequence.
        if len(generated_sequence) >= 50:
            tail = generated_sequence[-30:]
            occurrences = sum(
                1
                for i in range(len(generated_sequence) - len(tail) + 1)
                if generated_sequence[i : i + len(tail)] == tail
            )
            if occurrences >= 2:
                return True

        # Stop when the same decoded token repeats more than max_repeated_words
        # times in a row.
        generated_tokens = [self.tokenizer.decode(token_id) for token_id in input_ids[0]]
        count = 1
        prev_token = None
        for token in generated_tokens:
            if token == prev_token:
                count += 1
                if count > self.max_repeated_words:
                    return True
            else:
                count = 1
            prev_token = token

        # Note: nothing in this class ever adds to self.generated_sequences,
        # so this threshold never triggers in the current implementation.
        if len(self.generated_sequences) >= self.max_duplicate_sequences:
            return True

        return False
# Build the module-level stopping criteria consumed by EndpointHandler below:
# allow at most 1 duplicate sequence and 2 identical consecutive tokens.
max_duplicate_sequences = 1
max_repeated_words = 2

stop_criteria = StopGenerationCriteria(max_duplicate_sequences, max_repeated_words)
stopping_criteria = StoppingCriteriaList([stop_criteria])
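# For reference, the same list could be passed straight to model.generate(),
# e.g. model.generate(**inputs, stopping_criteria=stopping_criteria); the
# text-generation pipeline below forwards it to generate() the same way.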
class EndpointHandler:
    def __init__(self, model_path=""):
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            return_dict=True,
            device_map="auto",
            torch_dtype=dtype,
            trust_remote_code=True,
        )

        # Decoding defaults for the endpoint.
        generation_config = model.generation_config
        generation_config.max_new_tokens = 1700
        generation_config.min_length = 20
        generation_config.temperature = 1
        generation_config.top_p = 0.7
        generation_config.num_return_sequences = 1
        generation_config.pad_token_id = tokenizer.eos_token_id
        generation_config.eos_token_id = tokenizer.eos_token_id
        generation_config.repetition_penalty = 1.1

        gpipeline = transformers.pipeline(
            model=model,
            tokenizer=tokenizer,
            return_full_text=True,
            task="text-generation",
            stopping_criteria=stopping_criteria,
            generation_config=generation_config,
        )

        self.llm = HuggingFacePipeline(pipeline=gpipeline)

    def __call__(self, data: Dict[str, Any]) -> Any:
        # Inference Endpoints payloads look like {"inputs": "<prompt>"}; fall
        # back to the raw payload if the key is missing. LangChain's
        # HuggingFacePipeline returns the generated text as a plain string.
        prompt = data.pop("inputs", data)
        result = self.llm(prompt)
        return result
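# Minimal local smoke test (a sketch only; Hugging Face Inference Endpoints
# call EndpointHandler directly, so this block runs only when the file is
# executed by hand). The model repo name below is an assumption.
if __name__ == "__main__":
    handler = EndpointHandler(model_path="ClaudiaIoana550/try2_deploy_falcon")
    output = handler({"inputs": "Summarize why custom stopping criteria help."})
    print(output)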