File size: 1,064 Bytes
a25618b
 
 
 
 
 
 
 
 
1bac627
a25618b
 
 
 
 
 
 
 
 
 
 
 
 
6ec120b
a25618b
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from typing import  Dict, List, Any
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
import torch



class EndpointHandler():
    """Inference-endpoint handler that translates text with an M2M100 model.

    The model and tokenizer are loaded once in ``__init__`` and reused by
    every ``__call__``.
    """

    def __init__(self, path=""):
        """Load the M2M100 model and tokenizer from ``path``.

        Args:
            path: Model location passed to ``from_pretrained`` for both the
                model and the tokenizer.
        """
        # Weights are loaded in bfloat16 (presumably to reduce memory use
        # compared with the default float32 load).
        self.model = M2M100ForConditionalGeneration.from_pretrained(path, torch_dtype=torch.bfloat16)
        self.tokenizer = M2M100Tokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Translate ``data["text"]`` into the target language ``data["langId"]``.

        Args:
            data: Request payload. Must contain:
                - ``"text"``: the text to translate.
                - ``"langId"``: target language code, resolved to a token id
                  via ``tokenizer.get_lang_id``.

        Returns:
            ``{"translated": <str>}`` with the first decoded output sequence.

        Raises:
            ValueError: if ``"text"`` or ``"langId"`` is missing from ``data``.
        """
        text = data.get("text")
        lang_id = data.get("langId")
        # Fail fast with a clear message instead of forwarding the whole
        # payload dict into the tokenizer / get_lang_id.
        if text is None:
            raise ValueError("request payload must include a 'text' field")
        if lang_id is None:
            raise ValueError("request payload must include a 'langId' field (target language code)")

        tokenizer = self.tokenizer
        model = self.model

        # NOTE(review): the source language is whatever tokenizer.src_lang is
        # currently set to; it is not configurable per request here.
        encoded = tokenizer(text, return_tensors="pt")
        encoded = encoded.to(model.device)

        # forced_bos_token_id forces the first generated token to be the
        # target-language id, which selects the output language.
        generated_tokens = model.generate(**encoded, forced_bos_token_id=tokenizer.get_lang_id(lang_id))

        # Only the first decoded sequence is returned, even if several texts
        # were sent in one request.
        result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
        return {"translated": result}