import torch
from typing import Any, Dict, List
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from transformers import StoppingCriteria, StoppingCriteriaList

# Pick bfloat16 on Ampere-or-newer GPUs (compute capability >= 8), float16 otherwise.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    dtype = torch.bfloat16
else:
    dtype = torch.float16


class StoppingCriteriaSub(StoppingCriteria):
    """Stops generation once every sequence in the batch ends with one of the stop-word token sequences."""

    def __init__(self, stops=None):
        super().__init__()
        self.stops = stops if stops is not None else []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop in self.stops:
            stop = stop.to(input_ids.device)
            stop_len = len(stop)
            # Compare the trailing tokens of the generated sequences against the stop sequence.
            if input_ids.shape[1] >= stop_len:
                if torch.all(stop == input_ids[:, -stop_len:]).item():
                    return True
        return False


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the tokenizer and model once at startup.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path, device_map="auto", torch_dtype=dtype
        )
        print("model loaded")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None) or {}

        prompt = inputs
        temperature = parameters.get("temperature", 0.8)
        top_p = parameters.get("top_p", 0.9)
        top_k = parameters.get("top_k", 0)
        max_new_tokens = parameters.get("max_new_tokens", 100)
        repetition_penalty = parameters.get("repetition_penalty", 1.1)
        max_length = parameters.get("max_length", 2048)
        stop_words = parameters.get("stop_words", [])
        num_return_sequences = parameters.get("num_return_sequences", 1)

        # max_new_tokens takes precedence over max_length in generate(), so only
        # max_new_tokens is passed here; max_length is used below to truncate the prompt.
        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_new_tokens=max_new_tokens,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            repetition_penalty=repetition_penalty,
            num_return_sequences=num_return_sequences,
            do_sample=True,
        )

        # Tokenize the prompt, truncating so that prompt + generation fits in max_length.
        input_tokens = self.tokenizer.encode(
            prompt,
            return_tensors="pt",
            max_length=max_length - max_new_tokens,
            truncation=True,
        ).to(self.model.device)

        # Decode the (possibly truncated) prompt so it can be stripped from the output later.
        truncated_prompt = self.tokenizer.decode(input_tokens.squeeze(), skip_special_tokens=True)

        # Encode stop words without special tokens so BOS/EOS markers don't break the match;
        # squeeze(0) keeps single-token stop words as 1-D tensors so len() works on them.
        stop_words_ids = [
            self.tokenizer(stop_word, return_tensors="pt", add_special_tokens=False)["input_ids"].squeeze(0)
            for stop_word in stop_words
        ]
        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

        # All prompt tokens are real (no padding), so the attention mask is all ones.
        attention_mask = torch.ones_like(input_tokens)

        # Run generation.
        output = self.model.generate(
            input_tokens,
            generation_config=generation_config,
            stopping_criteria=stopping_criteria,
            attention_mask=attention_mask,
        )

        # Return only the part generated after the prompt.
        output_text = self.tokenizer.batch_decode(output, skip_special_tokens=True)[0][len(truncated_prompt):]
        return [{"generated_text": output_text}]
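
# A minimal local smoke-test sketch, under two assumptions: a causal-LM
# checkpoint lives at ./model (hypothetical path), and the request body follows
# the {"inputs": ..., "parameters": ...} schema the handler above expects.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")  # hypothetical checkpoint directory
    response = handler(
        {
            "inputs": "Once upon a time",
            "parameters": {
                "temperature": 0.7,
                "max_new_tokens": 50,
                "stop_words": ["\n\n"],
            },
        }
    )
    print(response[0]["generated_text"])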