wiusdy committed on
Commit
9c010ec
1 Parent(s): fc0c2dc

making the model comparison

Browse files
Files changed (2) hide show
  1. app.py +4 -5
  2. inference.py +24 -17
app.py CHANGED
@@ -7,17 +7,16 @@ inference = Inference()
7
 
8
 
9
  with gr.Blocks() as block:
10
- options = gr.Dropdown(choices=["ViLT", "Blip Saffal", "Blip CN"], label="Models", info="Select the model to use..", )
11
- # need to improve this one...
12
-
13
  txt = gr.Textbox(label="Insert a question..", lines=2)
14
- txt_3 = gr.Textbox(value="", label="Your answer is here..")
 
 
15
  btn = gr.Button(value="Submit")
16
 
17
  dogs = os.path.join(os.path.dirname(__file__), "617.jpg")
18
  image = gr.Image(type="pil", value=dogs)
19
 
20
- btn.click(inference.inference, inputs=[options, image, txt], outputs=[txt_3])
21
 
22
  if __name__ == "__main__":
23
  block.launch()
 
7
 
8
 
9
with gr.Blocks() as block:
    # Question typed by the user; the same question is sent to all four models.
    txt = gr.Textbox(label="Insert a question..", lines=2)

    # One answer box per model. The order of this list must match the tuple
    # returned by Inference.inference:
    # (BLIP saffal, BLIP control net, ViLT saffal, ViLT control net).
    # Fix: gr.outputs.Textbox is deprecated (removed in Gradio 4) — inside a
    # Blocks context the plain gr.Textbox component is the correct output.
    outputs = [
        gr.Textbox(label="Answer from BLIP saffal model"),
        gr.Textbox(label="Answer from BLIP control net"),
        gr.Textbox(label="Answer from ViLT saffal model"),
        gr.Textbox(label="Answer from ViLT control net"),
    ]

    btn = gr.Button(value="Submit")

    # Default example image shipped next to this script.
    dogs = os.path.join(os.path.dirname(__file__), "617.jpg")
    image = gr.Image(type="pil", value=dogs)

    btn.click(inference.inference, inputs=[image, txt], outputs=outputs)

if __name__ == "__main__":
    block.launch()
inference.py CHANGED
@@ -6,31 +6,38 @@ import torch
6
  class Inference:
7
  def __init__(self):
8
  self.vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
9
- self.vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
 
10
 
11
  self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
12
  self.blip_model_saffal = BlipForQuestionAnswering.from_pretrained("wiusdy/blip_pretrained_saffal_fashion_finetuning")
13
  self.blip_model_control_net = BlipForQuestionAnswering.from_pretrained("wiusdy/blip_pretrained_control_net_fashion_finetuning")
 
14
  logging.set_verbosity_info()
15
  self.logger = logging.get_logger("transformers")
16
 
17
- def inference(self, selected, image, text):
18
- self.logger.info(f"selected model {selected}")
19
- if selected == "ViLT":
20
- return self.__inference_vilt(image, text)
21
- elif selected == "Blip Saffal":
22
- return self.__inference_saffal_blip(image, text)
23
- elif selected == "Blip CN":
24
- return self.__inference_control_net_blip(image, text)
25
- else:
26
- self.logger.warning("Please select a model to make the inference..")
27
-
28
- def __inference_vilt(self, image, text):
 
 
 
 
 
 
29
  encoding = self.vilt_processor(image, text, return_tensors="pt")
30
- outputs = self.vilt_model(**encoding)
31
- logits = outputs.logits
32
- idx = logits.argmax(-1).item()
33
- return f"{self.vilt_model.config.id2label[idx]}"
34
 
35
  def __inference_saffal_blip(self, image, text):
36
  encoding = self.blip_processor(image, text, return_tensors="pt")
 
6
  class Inference:
7
  def __init__(self):
8
  self.vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
9
+ self.vilt_model_saffal = BlipForQuestionAnswering.from_pretrained("wiusdy/vilt_saffal_model")
10
+ self.vilt_model_control_net = BlipForQuestionAnswering.from_pretrained("wiusdy/vilt_control_net")
11
 
12
  self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
13
  self.blip_model_saffal = BlipForQuestionAnswering.from_pretrained("wiusdy/blip_pretrained_saffal_fashion_finetuning")
14
  self.blip_model_control_net = BlipForQuestionAnswering.from_pretrained("wiusdy/blip_pretrained_control_net_fashion_finetuning")
15
+
16
  logging.set_verbosity_info()
17
  self.logger = logging.get_logger("transformers")
18
 
19
+ def inference(self, image, text):
20
+ self.logger.info(f"Running inference for model ViLT Saffal")
21
+ ViLT_saffal_inference = self.__inference_vilt_saffal(image, text)
22
+ self.logger.info(f"Running inference for model ViLT Control Net")
23
+ ViLT_control_net_inference = self.__inference_vilt_control_net(image, text)
24
+ self.logger.info(f"Running inference for model BLIP Saffal")
25
+ BLIP_saffal_inference = self.__inference_saffal_blip(image, text)
26
+ self.logger.info(f"Running inference for model BLIP Control Net")
27
+ BLIP_control_net_inference = self.__inference_control_net_blip(image, text)
28
+ return BLIP_saffal_inference, BLIP_control_net_inference, ViLT_saffal_inference, ViLT_control_net_inference
29
+
30
+ def __inference_vilt_saffal(self, image, text):
31
+ encoding = self.vilt_processor(image, text, return_tensors="pt")
32
+ out = self.vilt_model_saffal.generate(**encoding)
33
+ generated_text = self.vilt_processor.decode(out[0], skip_special_tokens=True)
34
+ return f"{generated_text}"
35
+
36
+ def __inference_vilt_control_net(self, image, text):
37
  encoding = self.vilt_processor(image, text, return_tensors="pt")
38
+ out = self.vilt_model_control_net.generate(**encoding)
39
+ generated_text = self.vilt_processor.decode(out[0], skip_special_tokens=True)
40
+ return f"{generated_text}"
 
41
 
42
  def __inference_saffal_blip(self, image, text):
43
  encoding = self.blip_processor(image, text, return_tensors="pt")