ikeno-ada committed on
Commit a25618b · verified · 1 Parent(s): cbda76b

Create handler.py

Files changed (1)
  1. handler.py +28 -0
handler.py ADDED
@@ -0,0 +1,28 @@
+ from typing import Dict, List, Any
+ from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
+ import torch
+
+
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         # load the model in bfloat16 and let device_map="auto" place it on the available device(s)
+         self.model = M2M100ForConditionalGeneration.from_pretrained(path, device_map="auto", torch_dtype=torch.bfloat16)
+         self.tokenizer = M2M100Tokenizer.from_pretrained(path)
+
+     def __call__(self, data: Dict[str, str]) -> Dict[str, str]:
+         """
+         Args:
+             data (Dict[str, str]):
+                 includes the input text and the target language id for the inference.
+         """
+         text = data.get("text", data)
+         langId = data.get("langId", data)
+
+         # tokenize the input and move it to the model's device
+         encoded = self.tokenizer(text, return_tensors="pt").to(self.model.device)
+         # generate, forcing the target language as the first generated token
+         generated_tokens = self.model.generate(**encoded, forced_bos_token_id=self.tokenizer.get_lang_id(langId))
+         result = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
+         # return the translation
+         return {"translated": result}
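
For reference, a minimal local sanity check of this handler could look like the sketch below. It is only an illustration: it assumes the checkpoint path points to an M2M100 model (facebook/m2m100_418M is used here as a hypothetical example), that handler.py is importable from the working directory, and that the request payload uses the text / langId keys read above; the tokenizer's default source language ("en") is left unchanged, as in the handler itself.

from handler import EndpointHandler

# hypothetical checkpoint id; any M2M100 checkpoint should work the same way
handler = EndpointHandler(path="facebook/m2m100_418M")

# translate English (the tokenizer's default source language) into French
payload = {"text": "Hello world", "langId": "fr"}
print(handler(payload))  # -> {"translated": "..."} containing the French translation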