Spaces:
Sleeping
Sleeping
updating code and removing bugs
Browse files
- app.py +1 -1
- inference.py +13 -4
app.py
CHANGED
@@ -7,7 +7,7 @@ inference = Inference()
|
|
7 |
|
8 |
|
9 |
with gr.Blocks() as block:
|
10 |
-
options = gr.Dropdown(choices=["
|
11 |
# need to improve this one...
|
12 |
|
13 |
txt = gr.Textbox(label="Insert a question..", lines=2)
|
|
|
7 |
|
8 |
|
9 |
with gr.Blocks() as block:
|
10 |
+
options = gr.Dropdown(choices=["ViLT", "Blip Saffal", "Blip CN"], label="Models", info="Select the model to use..", )
|
11 |
# need to improve this one...
|
12 |
|
13 |
txt = gr.Textbox(label="Insert a question..", lines=2)
|
inference.py
CHANGED
@@ -9,7 +9,8 @@ class Inference:
|
|
9 |
self.vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
|
10 |
|
11 |
self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
|
12 |
-
self.
|
|
|
13 |
logging.set_verbosity_info()
|
14 |
self.logger = logging.get_logger("transformers")
|
15 |
|
@@ -18,7 +19,9 @@ class Inference:
|
|
18 |
if selected == "Model 1":
|
19 |
return self.__inference_vilt(image, text)
|
20 |
elif selected == "Model 2":
|
21 |
-
return self.
|
|
|
|
|
22 |
else:
|
23 |
self.logger.warning("Please select a model to make the inference..")
|
24 |
|
@@ -29,8 +32,14 @@ class Inference:
|
|
29 |
idx = logits.argmax(-1).item()
|
30 |
return f"{self.vilt_model.config.id2label[idx]}"
|
31 |
|
32 |
-
def
|
33 |
encoding = self.blip_processor(image, text, return_tensors="pt").to("cuda:0", torch.float16)
|
34 |
-
out = self.
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
generated_text = self.blip_processor.decode(out[0], skip_special_tokens=True)
|
36 |
return f"{generated_text}"
|
|
|
9 |
self.vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
|
10 |
|
11 |
self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
|
12 |
+
self.blip_model_saffal = BlipForQuestionAnswering.from_pretrained("wiusdy/blip_pretrained_saffal_fashion_finetuning").to("cuda")
|
13 |
+
self.blip_model_control_net = BlipForQuestionAnswering.from_pretrained("wiusdy/blip_pretrained_control_net_fashion_finetuning").to("cuda")
|
14 |
logging.set_verbosity_info()
|
15 |
self.logger = logging.get_logger("transformers")
|
16 |
|
|
|
19 |
if selected == "Model 1":
|
20 |
return self.__inference_vilt(image, text)
|
21 |
elif selected == "Model 2":
|
22 |
+
return self.__inference_saffal_blip(image, text)
|
23 |
+
elif selected == "Model 3":
|
24 |
+
return self.__inference_control_net_blip(image, text)
|
25 |
else:
|
26 |
self.logger.warning("Please select a model to make the inference..")
|
27 |
|
|
|
32 |
idx = logits.argmax(-1).item()
|
33 |
return f"{self.vilt_model.config.id2label[idx]}"
|
34 |
|
35 |
+
def __inference_saffal_blip(self, image, text):
|
36 |
encoding = self.blip_processor(image, text, return_tensors="pt").to("cuda:0", torch.float16)
|
37 |
+
out = self.blip_model_saffal.generate(**encoding)
|
38 |
+
generated_text = self.blip_processor.decode(out[0], skip_special_tokens=True)
|
39 |
+
return f"{generated_text}"
|
40 |
+
|
41 |
+
def __inference_control_net_blip(self, image, text):
|
42 |
+
encoding = self.blip_processor(image, text, return_tensors="pt").to("cuda:0", torch.float16)
|
43 |
+
out = self.blip_model_control_net.generate(**encoding)
|
44 |
generated_text = self.blip_processor.decode(out[0], skip_special_tokens=True)
|
45 |
return f"{generated_text}"
|