wiusdy committed
Commit: c5797b7
1 Parent(s): 787b7f5

using 3 new models of VQA

Files changed (2):
  1. app.py        +11 -3
  2. inference.py  +30 -9
app.py CHANGED
@@ -1,9 +1,10 @@
 import gradio as gr
 import os
 
-from inference import *
+from inference import Inference
 from utils import *
 
+inference = Inference()
 
 with gr.Blocks() as block:
     options = gr.Dropdown(choices=["Model 1", "Model 2", "Model 3"], label="Models", info="Select the model to use..")
@@ -15,9 +16,16 @@ with gr.Blocks() as block:
     dogs = os.path.join(os.path.dirname(__file__), "617.jpg")
     image = gr.Image(type="pil", value=dogs)
 
-    btn.click(inference, inputs=[image, txt], outputs=[txt_3])
+    btn.click(inference.inference_vilt, inputs=[image, txt], outputs=[txt_3])
 
     btn = gr.Button(value="Submit")
 
+iface = gr.Interface(
+    fn=lambda: "Selected Model: " + options.value,
+    inputs=block,
+    outputs="text",
+    live=False
+)
+
 if __name__ == "__main__":
-    block.launch()
+    iface.launch()
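
One caveat with the new wiring: gr.Interface expects individual input components rather than a whole Blocks object, and the lambda ignores its inputs, so iface.launch() is unlikely to serve the selected model's answer. Below is a minimal sketch of an alternative that keeps everything inside the Blocks context and dispatches on the dropdown value at click time; the txt and txt_3 textboxes are assumed names, since they sit in lines hidden by the diff, and the run helper is hypothetical.

# sketch: app.py wiring without the gr.Interface wrapper (txt / txt_3 are assumed names)
import gradio as gr
import os

from inference import Inference

inference = Inference()

def run(model_name, image, question):
    # Hypothetical dispatch helper: route the click to the method matching the dropdown choice.
    if model_name == "Model 2":
        return inference.inference_deplot(image, question)
    if model_name == "Model 3":
        return inference.inference_blip(image, question)
    return inference.inference_vilt(image, question)

with gr.Blocks() as block:
    options = gr.Dropdown(choices=["Model 1", "Model 2", "Model 3"], value="Model 1", label="Models", info="Select the model to use..")
    txt = gr.Textbox(label="Question")    # assumed, hidden in the diff
    txt_3 = gr.Textbox(label="Answer")    # assumed, hidden in the diff
    dogs = os.path.join(os.path.dirname(__file__), "617.jpg")
    image = gr.Image(type="pil", value=dogs)
    btn = gr.Button(value="Submit")
    # Define the button before wiring it; pass the dropdown so the handler sees the live selection.
    btn.click(run, inputs=[options, image, txt], outputs=[txt_3])

if __name__ == "__main__":
    block.launch()

Passing options as a click input hands the handler the user's current selection; reading options.value inside a lambda only ever returns the component's initial value.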
inference.py CHANGED
@@ -1,11 +1,32 @@
-from transformers import ViltProcessor, ViltForQuestionAnswering
+from transformers import ViltProcessor, ViltForQuestionAnswering, Pix2StructProcessor, Pix2StructForConditionalGeneration, Blip2Processor, Blip2ForConditionalGeneration
+import torch
 
-def inference(image, text):
-    processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
-    model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
-    encoding = processor(image, text, return_tensors="pt")
+class Inference:
+    def __init__(self):
+        self.vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
+        self.vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
 
-    outputs = model(**encoding)
-    logits = outputs.logits
-    idx = logits.argmax(-1).item()
-    return f"{model.config.id2label[idx]}"
+        self.deplot_processor = Pix2StructProcessor.from_pretrained('google/deplot')
+        self.deplot_model = Pix2StructForConditionalGeneration.from_pretrained('google/deplot')
+
+        self.blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
+        self.blip_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16)
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    def inference_vilt(self, image, text):
+        encoding = self.vilt_processor(image, text, return_tensors="pt")
+        outputs = self.vilt_model(**encoding)
+        logits = outputs.logits
+        idx = logits.argmax(-1).item()
+        return f"{self.vilt_model.config.id2label[idx]}"
+
+    def inference_deplot(self, image, text):
+        inputs = self.deplot_processor(images=image, text=text, return_tensors="pt")
+        predictions = self.deplot_model.generate(**inputs, max_new_tokens=512)
+        return f"{self.deplot_processor.decode(predictions[0], skip_special_tokens=True)}"
+
+    def inference_blip(self, image, text):
+        inputs = self.blip_processor(images=image, return_tensors="pt").to(self.device, torch.float16)
+        generated_ids = self.blip_model.generate(**inputs)
+        generated_text = self.blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+        return f"{generated_text}"
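
A quick smoke test of the new Inference class, as a sketch: 617.jpg is the sample image already bundled with the Space, the question string is a placeholder, and the BLIP-2 branch needs a CUDA device plus bitsandbytes because the model is loaded with load_in_8bit=True. The BLIP-2 method only encodes the image, so it returns a caption and ignores the question, and the ViLT and DePlot models run on CPU as loaded.

from PIL import Image

from inference import Inference

inference = Inference()
image = Image.open("617.jpg")                       # sample image bundled with the Space
question = "How many dogs are in the picture?"      # placeholder question

print(inference.inference_vilt(image, question))    # ViLT: short VQA answer label
print(inference.inference_deplot(image, question))  # DePlot: chart-to-text output (meant for plots/charts)
print(inference.inference_blip(image, question))    # BLIP-2: image caption (the question is ignored)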