from typing import Dict
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
import torch


class EndpointHandler():
    def __init__(self, path=""):
        # load the optimized (bfloat16) model and its tokenizer
        self.model = M2M100ForConditionalGeneration.from_pretrained(path, torch_dtype=torch.bfloat16)
        self.tokenizer = M2M100Tokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, str]) -> Dict[str, str]:
        """
        Args:
            data (:obj:`dict`):
                includes the input text ("text") and the target language
                id ("langId") for the inference.
        """
        text = data["text"]
        lang_id = data["langId"]
        # tokenize the input and move it to the model's device
        encoded = self.tokenizer(text, return_tensors="pt")
        encoded = encoded.to(self.model.device)
        # run the model without tracking gradients
        with torch.no_grad():
            generated_tokens = self.model.generate(
                **encoded, forced_bos_token_id=self.tokenizer.get_lang_id(lang_id)
            )
        result = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
        # return the translation
        return {"translated": result}
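

# Usage sketch (a minimal, hypothetical example, not part of the endpoint
# contract): it assumes the model weights live at a local path such as
# "./m2m100_418M" and that French ("fr") is the desired target language;
# adjust both for your deployment.
if __name__ == "__main__":
    handler = EndpointHandler(path="./m2m100_418M")
    output = handler({"text": "Hello, world!", "langId": "fr"})
    print(output["translated"])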