from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation.logits_process import LogitsProcessorList, InfNanRemoveLogitsProcessor
from transformers_gad.grammar_utils import IncrementalGrammarConstraint
from transformers_gad.generation.logits_process import GrammarAlignedOracleLogitsProcessor


class EndpointHandler():
    def __init__(self, path=""):
        # Preload the tokenizer and model once at endpoint startup.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[str]:
        # Pull the prompt and the grammar out of the request payload.
        inputs = data.get("inputs", data)
        grammar_str = data.get("grammar", "")
        MAX_NEW_TOKENS = 4096
        MAX_TIME = 300  # seconds; generation stops after this wall-clock budget
        print(grammar_str)  # log the received grammar for debugging
        # Build an incremental constraint from the grammar, starting at its "root" rule.
        grammar = IncrementalGrammarConstraint(grammar_str, "root", self.tokenizer)
        # Initialize logits processors: strip inf/nan logits first, then apply the
        # grammar-aligned oracle so sampling stays within the grammar.
        gad_oracle_processor = GrammarAlignedOracleLogitsProcessor(grammar)
        inf_nan_remove_processor = InfNanRemoveLogitsProcessor()
        logits_processors = LogitsProcessorList([
            inf_nan_remove_processor,
            gad_oracle_processor,
        ])
        input_ids = self.tokenizer([inputs], add_special_tokens=False, return_tensors="pt")["input_ids"]
        output = self.model.generate(
            input_ids,
            do_sample=True,
            max_time=MAX_TIME,
            max_new_tokens=MAX_NEW_TOKENS,
            logits_processor=logits_processors,
        )
        # Clear the oracle's cached state so the next request starts fresh.
        gad_oracle_processor.reset()
        # Detokenize only the newly generated tokens, dropping the prompt prefix.
        input_length = 1 if self.model.config.is_encoder_decoder else input_ids.shape[1]
        if hasattr(output, "sequences"):
            generated_tokens = output.sequences[:, input_length:]
        else:
            generated_tokens = output[:, input_length:]
        generations = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        return generations
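

# --- Usage sketch (not part of the deployed handler) ---
# A minimal local smoke test, assuming a causal LM checkpoint at MODEL_PATH;
# the path and the GBNF-style grammar below are illustrative placeholders only.
if __name__ == "__main__":
    MODEL_PATH = "path/to/model"  # hypothetical local checkpoint directory
    handler = EndpointHandler(MODEL_PATH)
    demo_grammar = 'root ::= "yes" | "no"'  # illustrative two-word grammar
    result = handler({"inputs": "Is the sky blue? Answer:", "grammar": demo_grammar})
    print(result)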