wiusdy commited on
Commit
6e61211
1 Parent(s): cece4fa
Files changed (1) hide show
  1. inference.py +4 -0
inference.py CHANGED
@@ -1,4 +1,5 @@
1
  from transformers import ViltProcessor, ViltForQuestionAnswering, Pix2StructProcessor, Pix2StructForConditionalGeneration, Blip2Processor, Blip2ForConditionalGeneration
 
2
 
3
  class Inference:
4
  def __init__(self):
@@ -9,12 +10,15 @@ class Inference:
9
  self.deplot_model = Pix2StructForConditionalGeneration.from_pretrained('google/deplot')
10
 
11
  def inference(self, selected, image, text):
 
12
  if selected == "Model 1":
13
  return self.__inference_deplot(image, text)
14
  elif selected == "Model 2":
15
  return self.__inference_deplot(image, text)
16
  elif selected == "Model 3":
17
  return self.__inference_vilt(image, text)
 
 
18
 
19
  def __inference_vilt(self, image, text):
20
  encoding = self.vilt_processor(image, text, return_tensors="pt")
 
1
  from transformers import ViltProcessor, ViltForQuestionAnswering, Pix2StructProcessor, Pix2StructForConditionalGeneration, Blip2Processor, Blip2ForConditionalGeneration
2
+ from transformers.utils import logging
3
 
4
  class Inference:
5
  def __init__(self):
 
10
  self.deplot_model = Pix2StructForConditionalGeneration.from_pretrained('google/deplot')
11
 
12
  def inference(self, selected, image, text):
13
+ logger.info(f"selected model {selected}")
14
  if selected == "Model 1":
15
  return self.__inference_deplot(image, text)
16
  elif selected == "Model 2":
17
  return self.__inference_deplot(image, text)
18
  elif selected == "Model 3":
19
  return self.__inference_vilt(image, text)
20
+ else:
21
+ logger.warning("Please select a model to make the inference..")
22
 
23
  def __inference_vilt(self, image, text):
24
  encoding = self.vilt_processor(image, text, return_tensors="pt")