wiusdy committed on
Commit
e727785
1 Parent(s): 0306e1b

updating code and fixing bugs

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. inference.py +13 -4
app.py CHANGED
@@ -7,7 +7,7 @@ inference = Inference()
7
 
8
 
9
  with gr.Blocks() as block:
10
- options = gr.Dropdown(choices=["Model 1", "Model 2"], label="Models", info="Select the model to use..", )
11
  # need to improve this one...
12
 
13
  txt = gr.Textbox(label="Insert a question..", lines=2)
 
7
 
8
 
9
  with gr.Blocks() as block:
10
+ options = gr.Dropdown(choices=["ViLT", "Blip Saffal", "Blip CN"], label="Models", info="Select the model to use..", )
11
  # need to improve this one...
12
 
13
  txt = gr.Textbox(label="Insert a question..", lines=2)
inference.py CHANGED
@@ -9,7 +9,8 @@ class Inference:
9
  self.vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
10
 
11
  self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
12
- self.blip_model = BlipForQuestionAnswering.from_pretrained("wiusdy/blip_pretrained_saffal_fashion_finetuning").to("cuda")
 
13
  logging.set_verbosity_info()
14
  self.logger = logging.get_logger("transformers")
15
 
@@ -18,7 +19,9 @@ class Inference:
18
  if selected == "Model 1":
19
  return self.__inference_vilt(image, text)
20
  elif selected == "Model 2":
21
- return self.__inference_deplot(image, text)
 
 
22
  else:
23
  self.logger.warning("Please select a model to make the inference..")
24
 
@@ -29,8 +32,14 @@ class Inference:
29
  idx = logits.argmax(-1).item()
30
  return f"{self.vilt_model.config.id2label[idx]}"
31
 
32
- def __inference_deplot(self, image, text):
33
  encoding = self.blip_processor(image, text, return_tensors="pt").to("cuda:0", torch.float16)
34
- out = self.blip_model.generate(**encoding)
 
 
 
 
 
 
35
  generated_text = self.blip_processor.decode(out[0], skip_special_tokens=True)
36
  return f"{generated_text}"
 
9
  self.vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
10
 
11
  self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
12
+ self.blip_model_saffal = BlipForQuestionAnswering.from_pretrained("wiusdy/blip_pretrained_saffal_fashion_finetuning").to("cuda")
13
+ self.blip_model_control_net = BlipForQuestionAnswering.from_pretrained("wiusdy/blip_pretrained_control_net_fashion_finetuning").to("cuda")
14
  logging.set_verbosity_info()
15
  self.logger = logging.get_logger("transformers")
16
 
 
19
  if selected == "Model 1":
20
  return self.__inference_vilt(image, text)
21
  elif selected == "Model 2":
22
+ return self.__inference_saffal_blip(image, text)
23
+ elif selected == "Model 3":
24
+ return self.__inference_control_net_blip(image, text)
25
  else:
26
  self.logger.warning("Please select a model to make the inference..")
27
 
 
32
  idx = logits.argmax(-1).item()
33
  return f"{self.vilt_model.config.id2label[idx]}"
34
 
35
+ def __inference_saffal_blip(self, image, text):
36
  encoding = self.blip_processor(image, text, return_tensors="pt").to("cuda:0", torch.float16)
37
+ out = self.blip_model_saffal.generate(**encoding)
38
+ generated_text = self.blip_processor.decode(out[0], skip_special_tokens=True)
39
+ return f"{generated_text}"
40
+
41
+ def __inference_control_net_blip(self, image, text):
42
+ encoding = self.blip_processor(image, text, return_tensors="pt").to("cuda:0", torch.float16)
43
+ out = self.blip_model_control_net.generate(**encoding)
44
  generated_text = self.blip_processor.decode(out[0], skip_special_tokens=True)
45
  return f"{generated_text}"