wiusdy committed on
Commit
31d780f
1 Parent(s): 9c010ec

solving ViLT problem

Browse files
Files changed (2) hide show
  1. app.py +1 -2
  2. inference.py +1 -21
app.py CHANGED
@@ -8,8 +8,7 @@ inference = Inference()
8
 
9
  with gr.Blocks() as block:
10
  txt = gr.Textbox(label="Insert a question..", lines=2)
11
- outputs = [gr.outputs.Textbox(label="Answer from BLIP saffal model"), gr.outputs.Textbox(label="Answer from BLIP control net"),
12
- gr.outputs.Textbox(label="Answer from ViLT saffal model"), gr.outputs.Textbox(label="Answer from ViLT control net")]
13
 
14
  btn = gr.Button(value="Submit")
15
 
 
8
 
9
  with gr.Blocks() as block:
10
  txt = gr.Textbox(label="Insert a question..", lines=2)
11
+ outputs = [gr.outputs.Textbox(label="Answer from BLIP saffal model"), gr.outputs.Textbox(label="Answer from BLIP control net")]
 
12
 
13
  btn = gr.Button(value="Submit")
14
 
inference.py CHANGED
@@ -5,10 +5,6 @@ import torch
5
 
6
  class Inference:
7
  def __init__(self):
8
- self.vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
9
- self.vilt_model_saffal = BlipForQuestionAnswering.from_pretrained("wiusdy/vilt_saffal_model")
10
- self.vilt_model_control_net = BlipForQuestionAnswering.from_pretrained("wiusdy/vilt_control_net")
11
-
12
  self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
13
  self.blip_model_saffal = BlipForQuestionAnswering.from_pretrained("wiusdy/blip_pretrained_saffal_fashion_finetuning")
14
  self.blip_model_control_net = BlipForQuestionAnswering.from_pretrained("wiusdy/blip_pretrained_control_net_fashion_finetuning")
@@ -17,27 +13,11 @@ class Inference:
17
  self.logger = logging.get_logger("transformers")
18
 
19
  def inference(self, image, text):
20
- self.logger.info(f"Running inference for model ViLT Saffal")
21
- ViLT_saffal_inference = self.__inference_vilt_saffal(image, text)
22
- self.logger.info(f"Running inference for model ViLT Control Net")
23
- ViLT_control_net_inference = self.__inference_vilt_control_net(image, text)
24
  self.logger.info(f"Running inference for model BLIP Saffal")
25
  BLIP_saffal_inference = self.__inference_saffal_blip(image, text)
26
  self.logger.info(f"Running inference for model BLIP Control Net")
27
  BLIP_control_net_inference = self.__inference_control_net_blip(image, text)
28
- return BLIP_saffal_inference, BLIP_control_net_inference, ViLT_saffal_inference, ViLT_control_net_inference
29
-
30
- def __inference_vilt_saffal(self, image, text):
31
- encoding = self.vilt_processor(image, text, return_tensors="pt")
32
- out = self.vilt_model_saffal.generate(**encoding)
33
- generated_text = self.vilt_processor.decode(out[0], skip_special_tokens=True)
34
- return f"{generated_text}"
35
-
36
- def __inference_vilt_control_net(self, image, text):
37
- encoding = self.vilt_processor(image, text, return_tensors="pt")
38
- out = self.vilt_model_control_net.generate(**encoding)
39
- generated_text = self.vilt_processor.decode(out[0], skip_special_tokens=True)
40
- return f"{generated_text}"
41
 
42
  def __inference_saffal_blip(self, image, text):
43
  encoding = self.blip_processor(image, text, return_tensors="pt")
 
5
 
6
  class Inference:
7
  def __init__(self):
 
 
 
 
8
  self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
9
  self.blip_model_saffal = BlipForQuestionAnswering.from_pretrained("wiusdy/blip_pretrained_saffal_fashion_finetuning")
10
  self.blip_model_control_net = BlipForQuestionAnswering.from_pretrained("wiusdy/blip_pretrained_control_net_fashion_finetuning")
 
13
  self.logger = logging.get_logger("transformers")
14
 
15
  def inference(self, image, text):
 
 
 
 
16
  self.logger.info(f"Running inference for model BLIP Saffal")
17
  BLIP_saffal_inference = self.__inference_saffal_blip(image, text)
18
  self.logger.info(f"Running inference for model BLIP Control Net")
19
  BLIP_control_net_inference = self.__inference_control_net_blip(image, text)
20
+ return BLIP_saffal_inference, BLIP_control_net_inference
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def __inference_saffal_blip(self, image, text):
23
  encoding = self.blip_processor(image, text, return_tensors="pt")