Spaces:
Sleeping
Sleeping
logging
Browse files- inference.py +4 -0
inference.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
from transformers import ViltProcessor, ViltForQuestionAnswering, Pix2StructProcessor, Pix2StructForConditionalGeneration, Blip2Processor, Blip2ForConditionalGeneration
|
|
|
2 |
|
3 |
class Inference:
|
4 |
def __init__(self):
|
@@ -9,12 +10,15 @@ class Inference:
|
|
9 |
self.deplot_model = Pix2StructForConditionalGeneration.from_pretrained('google/deplot')
|
10 |
|
11 |
def inference(self, selected, image, text):
|
|
|
12 |
if selected == "Model 1":
|
13 |
return self.__inference_deplot(image, text)
|
14 |
elif selected == "Model 2":
|
15 |
return self.__inference_deplot(image, text)
|
16 |
elif selected == "Model 3":
|
17 |
return self.__inference_vilt(image, text)
|
|
|
|
|
18 |
|
19 |
def __inference_vilt(self, image, text):
|
20 |
encoding = self.vilt_processor(image, text, return_tensors="pt")
|
|
|
1 |
from transformers import ViltProcessor, ViltForQuestionAnswering, Pix2StructProcessor, Pix2StructForConditionalGeneration, Blip2Processor, Blip2ForConditionalGeneration
|
2 |
+
from transformers.utils import logging
|
3 |
|
4 |
class Inference:
|
5 |
def __init__(self):
|
|
|
10 |
self.deplot_model = Pix2StructForConditionalGeneration.from_pretrained('google/deplot')
|
11 |
|
12 |
def inference(self, selected, image, text):
|
13 |
+
logger.info(f"selected model {selected}")
|
14 |
if selected == "Model 1":
|
15 |
return self.__inference_deplot(image, text)
|
16 |
elif selected == "Model 2":
|
17 |
return self.__inference_deplot(image, text)
|
18 |
elif selected == "Model 3":
|
19 |
return self.__inference_vilt(image, text)
|
20 |
+
else:
|
21 |
+
logger.warning("Please select a model to make the inference..")
|
22 |
|
23 |
def __inference_vilt(self, image, text):
|
24 |
encoding = self.vilt_processor(image, text, return_tensors="pt")
|