|
from typing import Any, Dict

from transformers import AutoModelForSeq2SeqLM, NllbTokenizerFast


class EndpointHandler:

    def __init__(self, path: str = ""):
        # Load the NLLB translation model in 4-bit precision (requires the
        # bitsandbytes and accelerate packages) together with its fast
        # tokenizer from the repository path provided to the endpoint.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path, load_in_4bit=True)
        self.tokenizer = NllbTokenizerFast.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """
        Args:
            data (:obj:`dict`):
                Input data and inference parameters: the text to translate
                under "text" and the NLLB target-language code
                (e.g. "fra_Latn") under "langId".
        Returns:
            A dict with the translation under the key "translated".
        """
        text = data.get("text", "")
        langId = data.get("langId", "eng_Latn")  # default target: English

        # Tokenize the source text and move the input tensors onto the same
        # device as the model.
        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)

        # Force the decoder to start with the target-language token, then
        # decode the generated ids back into text.
        translated_tokens = self.model.generate(
            **inputs,
            forced_bos_token_id=self.tokenizer.lang_code_to_id[langId],
            max_length=512,
        )
        res = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

        return {"translated": res}