Jeney commited on
Commit
2e2a941
1 Parent(s): 050b504

Add all answers in output

Browse files
Files changed (1) hide show
  1. handler.py +12 -4
handler.py CHANGED
@@ -13,17 +13,25 @@ class EndpointHandler:
13
  self.model = ViltForQuestionAnswering.from_pretrained(path)
14
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
- def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
17
  # process input
18
  inputs = data.pop("inputs", data)
19
  image = inputs["image"]
20
  image = Image.open(io.BytesIO(eval(image)))
21
  text = inputs["text"]
22
-
23
  # preprocess
24
  encoding = self.processor(image, text, return_tensors="pt")
25
  outputs = self.model(**encoding)
26
  # postprocess the prediction
27
  logits = outputs.logits
28
- idx = logits.argmax(-1).item()
29
- return [{"best_answer": self.model.config.id2label[idx], "outputs": str(outputs)}]
 
 
 
 
 
 
 
 
 
13
  self.model = ViltForQuestionAnswering.from_pretrained(path)
14
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
17
  # process input
18
  inputs = data.pop("inputs", data)
19
  image = inputs["image"]
20
  image = Image.open(io.BytesIO(eval(image)))
21
  text = inputs["text"]
22
+
23
  # preprocess
24
  encoding = self.processor(image, text, return_tensors="pt")
25
  outputs = self.model(**encoding)
26
  # postprocess the prediction
27
  logits = outputs.logits
28
+ best_idx = logits.argmax(-1).item()
29
+ best_answer = self.model.config.id2label[best_idx]
30
+ probabilities = torch.softmax(logits, dim=-1)[0]
31
+ id2label = self.model.config.id2label
32
+ answers = []
33
+ for idx, prob in enumerate(probabilities):
34
+ answer = id2label[idx]
35
+ answer_score = float(prob)
36
+ answers.append({"answer": answer, "answer_score": answer_score})
37
+ return {"best_answer": best_answer, "answers": answers}