# Hugging Face Space: 333pg / app.py
# Last commit 8a9693e by James332: "Rename app (1).py to app.py"
import gradio as gr
from transformers import ViltProcessor, ViltForQuestionAnswering
from transformers import AutoProcessor, AutoModelForVisualQuestionAnswering
from PIL import Image
import torch
# Hub identifiers that are not referenced below. By their names, these are
# presumably the training dataset, the original base checkpoint, and the
# published fine-tuned model; confirm before relying on them.
dataset_name = "Multimodal-Fatima/OK-VQA_train"
original_model_name = "microsoft/git-base-vqav2"
model_name = "hyo37009/git-vqa-finetuned-on-ok-vqa"

# Local directory that the processor and model are actually loaded from.
model_path = "git-vqa-finetuned-on-ok-vqa"

# Example questions. Question number n (1-based) is asked about the image
# file "image{n}.jpg".
questions = [
    "What can happen the objects shown are thrown on the ground?",
    "What was the machine beside the bowl used for?",
    "What kind of cars are in the photo?",
    "What is the hairstyle of the blond called?",
    "How old do you have to be in canada to do this?",
    "Can you guess the place where the man is playing?",
    "What loony tune character is in this photo?",
    "Whose birthday is being celebrated?",
    "Where can that toilet seat be bought?",
    "What do you call the kind of pants that the man on the right is wearing?",
]

# Load the processor, then the model, once at import time so that every
# request handled by ``main`` reuses the same objects.
processor, model = (
    loader.from_pretrained(model_path)
    for loader in (AutoProcessor, AutoModelForVisualQuestionAnswering)
)
def main(select_exemple_num):
    """Answer one of the bundled example questions with the VQA model.

    Args:
        select_exemple_num: 1-based example number (supplied by the Gradio
            slider). Example ``n`` is the image file ``image{n}.jpg``,
            resolved relative to the working directory, paired with
            ``questions[n - 1]``.

    Returns:
        A tuple ``(image_path, question, report)``:
        - the example image path;
        - the question text;
        - a text report listing up to five labels with their sigmoid
          scores, followed by the best label.

    Raises:
        ValueError: if the example number is outside ``1..len(questions)``.
    """
    # Slider values can arrive as floats (e.g. 3.0). A float would produce
    # "image3.0.jpg" and is not a valid list index.
    selectednum = int(select_exemple_num)
    if not 1 <= selectednum <= len(questions):
        # Without this check, 0 or a negative number would silently wrap
        # around to the end of ``questions`` through negative indexing.
        raise ValueError(
            f"example number must be between 1 and {len(questions)}, "
            f"got {select_exemple_num!r}"
        )
    exemple_img = f"image{selectednum}.jpg"
    question = questions[selectednum - 1]

    # The ``with`` block closes the image file once the processor has
    # consumed its pixels.
    with Image.open(exemple_img) as img:
        encoding = processor(img, question, return_tensors="pt")

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**encoding).logits

    # Each label is scored independently with a sigmoid (not a softmax),
    # so the scores need not sum to 1.
    scores = torch.sigmoid(logits)
    # Cap k so topk does not fail on a label space smaller than 5.
    k = min(5, scores.shape[-1])
    probs, classes = torch.topk(scores, k)

    # Row 0 holds the single (image, question) pair encoded above.
    id2label = model.config.id2label
    top_probs = probs[0].tolist()
    top_labels = [id2label[idx] for idx in classes[0].tolist()]
    for prob, label in zip(top_probs, top_labels):
        print(prob, label)

    # topk returns scores in descending order; take the best non-empty label.
    ans = next((label for label in top_labels if label), "")
    print(ans)

    ranking = "".join(
        f"{prob} {label}\n" for prob, label in zip(top_probs, top_labels)
    )
    output_str = (
        "predicted : \n"
        + ranking
        + f"\nso I think its answer is : \n{ans}"
    )
    return exemple_img, question, output_str
# Build the UI: a 1-based example selector goes in; the example image, its
# question and the prediction report come out (in the order ``main``
# returns them).
demo = gr.Interface(
    main,
    [gr.Slider(minimum=1, maximum=len(questions), step=1)],
    ["image", "text", "text"],
)

# Start the server; share=True additionally requests a public share link.
demo.launch(share=True)