MinxuanQin committed
Commit d43497c
1 Parent(s): 2e4b982
fix model load error
Files changed:
- app.py (+2 -2)
- model_loader.py (+15 -1)
app.py
CHANGED
@@ -17,7 +17,7 @@ df = pd.read_json('vqa_samples.json', orient="columns")
 # define selector
 model_name = st.sidebar.selectbox(
     "Select a model: ",
-    ('vilt', 'git', 'blip', 'vbert')
+    ('vilt', 'vilt_finetuned', 'git', 'blip', 'vbert')
 )
 
 image_selector_unspecific = st.number_input(
@@ -41,4 +41,4 @@ question = st.text_input(f"Ask the model a question related to the image: \n"
 args = load_model(model_name) # TODO: cache
 answer = get_answer(args, image, question, model_name)
 st.text(f"Answer by {model_name}: {answer}")
-st.text(f"Ground truth: {label}")
+st.text(f"Ground truth (of the example): {label}")
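The `# TODO: cache` note above refers to app.py re-loading the selected model on every Streamlit rerun. A minimal sketch of one way to handle it, assuming Streamlit 1.18+ and a hypothetical wrapper `load_model_cached` that is not part of this commit:

import streamlit as st
from model_loader import load_model  # loader defined in this repo

@st.cache_resource  # keep one loaded model per selected name across reruns
def load_model_cached(name: str):
    # Delegate to the existing load_model; the decorator memoizes its return value.
    return load_model(name)

model_name = st.sidebar.selectbox(
    "Select a model: ",
    ('vilt', 'vilt_finetuned', 'git', 'blip', 'vbert')
)
args = load_model_cached(model_name)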
model_loader.py
CHANGED
@@ -33,7 +33,10 @@ VQA_URL = "https://dl.fbaipublicfiles.com/pythia/data/answers_vqa.txt"
 def load_model(name):
     if name == "vilt":
         processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
-        model = ViltForQuestionAnswering.from_pretrained("
+        model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
+    elif name == "vilt_finetuned":
+        processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
+        model = ViltForQuestionAnswering.from_pretrained("Minqin/carets_vqa_finetuned")
     elif name == "git":
         processor = AutoProcessor.from_pretrained("microsoft/git-base-vqav2")
         model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vqav2")
@@ -155,6 +158,17 @@ def get_answer(model_loader_args, img, question, model_name):
             logits = outputs.logits
             idx = logits.argmax(-1).item()
             pred = model.config.id2label[idx]
+
+    elif model_name == "vilt_finetuned":
+        try:
+            encoding = processor(images=img, text=question, return_tensors="pt")
+        except Exception:
+            return err_msg()
+        else:
+            outputs = model(**encoding)
+            logits = outputs.logits
+            idx = logits.argmax(-1).item()
+            pred = model.config.id2label[idx]
 
     elif model_name == "git":
         try:
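For reference, a minimal sketch of how the new `vilt_finetuned` branch could be exercised outside the Streamlit app, assuming a hypothetical local image `cat.jpg` and that `load_model`'s return value is passed straight through to `get_answer`, as app.py does:

from PIL import Image
from model_loader import load_model, get_answer

args = load_model("vilt_finetuned")   # ViLT processor + checkpoint Minqin/carets_vqa_finetuned
image = Image.open("cat.jpg")         # hypothetical example image
question = "What animal is in the picture?"
print(get_answer(args, image, question, "vilt_finetuned"))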