import torch
from typing import Any, Dict, List

from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from transformers import StoppingCriteria, StoppingCriteriaList


# Prefer bfloat16 on Ampere-or-newer GPUs (compute capability >= 8, not only == 8);
# guard against CPU-only hosts, where querying the device capability would raise.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    dtype = torch.bfloat16
else:
    dtype = torch.float16


class StoppingCriteriaSub(StoppingCriteria):
    """Stops generation once the sequence ends with any of the given stop-word token ids."""

    def __init__(self, stops=None, device="cuda"):
        super().__init__()
        # Avoid a mutable default argument; move each stop sequence to the target device.
        self.stops = [stop.to(device) for stop in (stops or [])]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop in self.stops:
            stop_len = len(stop)
            # Compare only once the generated sequence is at least as long as the stop word.
            if input_ids.shape[1] >= stop_len and torch.all(stop == input_ids[:, -stop_len:]).item():
                return True
        return False


class EndpointHandler:
    def __init__(self, path=""):
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", torch_dtype=dtype)
        print("model loaded")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None) or {}

        prompt = inputs
        temperature = parameters.get("temperature", 0.8)
        top_p = parameters.get("top_p", 0.9)
        top_k = parameters.get("top_k", 0)
        max_new_tokens = parameters.get("max_new_tokens", 100)
        # This value feeds repetition_penalty, so read it from the matching key.
        repetition_penalty = parameters.get("repetition_penalty", 1.1)
        max_length = parameters.get("max_length", 2048)
        stop_words = parameters.get("stop_words", [])
        num_return_sequences = parameters.get("num_return_sequences", 1)

        # Many causal-LM tokenizers define no pad token; fall back to EOS so
        # generate() has a valid id to pad with.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_new_tokens=max_new_tokens,  # takes precedence over max_length when both are set
            max_length=max_length,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=pad_token_id,
            repetition_penalty=repetition_penalty,
            num_return_sequences=num_return_sequences,
            do_sample=True,
        )

        # Truncate the prompt so that prompt plus generation still fits in max_length.
        input_tokens = self.tokenizer.encode(
            prompt,
            return_tensors="pt",
            max_length=max_length - max_new_tokens,
            truncation=True,
        ).to(self.model.device)

        # add_special_tokens=False keeps BOS/EOS out of the stop-word ids, which would
        # otherwise prevent them from ever matching the tail of the generated sequence;
        # indexing [0] (rather than squeeze()) keeps single-token stop words 1-D.
        stop_words_ids = [
            self.tokenizer(stop_word, return_tensors="pt", add_special_tokens=False)["input_ids"][0]
            for stop_word in stop_words
        ]
        stopping_criteria = StoppingCriteriaList(
            [StoppingCriteriaSub(stops=stop_words_ids, device=self.model.device)]
        )

        attention_mask = torch.ones_like(input_tokens)

        output = self.model.generate(
            input_tokens,
            generation_config=generation_config,
            stopping_criteria=stopping_criteria,
            attention_mask=attention_mask,
        )

        # Strip the prompt at the token level rather than by decoded string length,
        # which can drift when the tokenizer normalizes whitespace or special tokens.
        prompt_len = input_tokens.shape[1]
        generated_texts = self.tokenizer.batch_decode(output[:, prompt_len:], skip_special_tokens=True)
        # One entry per sequence, so num_return_sequences > 1 is honored.
        return [{"generated_text": text} for text in generated_texts]
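

# Minimal local smoke test; a sketch only. It assumes a model directory at
# "./model" (hypothetical path) and the request shape the handler expects,
# {"inputs": ..., "parameters": {...}}.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")
    response = handler(
        {
            "inputs": "Once upon a time",
            "parameters": {"max_new_tokens": 32, "stop_words": ["\n\n"]},
        }
    )
    print(response)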