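# Gradio demo: visual question answering (VQA) with a fine-tuned ALBEF model.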
import json

import gradio as gr
import torch
from PIL import Image
from ruamel import yaml

from data.transforms import ALBEFTextTransform, testing_image_transform
from model import albef_model_for_vqa
data_dir = "./"
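
# Build the ALBEF VQA model from its YAML config.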
config = yaml.load(open("./configs/vqa.yaml", "r"), Loader=yaml.Loader)
model = albef_model_for_vqa(config)
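
# Fetch the fine-tuned VQA checkpoint and load its weights.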
checkpoint_url = "https://download.pytorch.org/models/multimodal/albef/finetuned_vqa_checkpoint.pt"
checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
model.load_state_dict(checkpoint)
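
# Transforms for the input image, the question text, and the candidate answers.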
image_transform = testing_image_transform()
question_transform = ALBEFTextTransform(add_end_token=False)
answer_transform = ALBEFTextTransform(do_pre_process=False)
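
# Closed answer vocabulary: at inference the model ranks these candidates.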
answer_list = json.load(open(data_dir + "answer_list.json", "r"))


def infer(image, question):
    # Batch the single input image and apply the evaluation transform.
    images = [image]
    image_input = [image_transform(img) for img in images]
    image_input = torch.stack(image_input, dim=0)
    # Tokenize the question and build its attention mask (pad token id is 0).
    question_input = question_transform([question])
    question_atts = (question_input != 0).type(torch.long)
    # Tokenize every candidate answer and build the corresponding masks.
    answer_input = answer_transform(answer_list)
    answer_atts = (answer_input != 0).type(torch.long)
    # Score the candidate answers; k=1 keeps only the top prediction.
    answer_ids, _ = model(
        image_input,
        question_input,
        question_atts,
        answer_input,
        answer_atts,
        k=1,
        is_train=False,
    )
    predicted_answer_id = answer_ids[0]
    predicted_answer = answer_list[predicted_answer_id]
    return predicted_answer
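
# Wire the inference function into a simple Gradio interface.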
demo = gr.Interface(
    fn=infer,
    inputs=[gr.Image(label="image", type="pil", image_mode="RGB"), gr.Text(label="question")],
    outputs=gr.Text(label="answer"),
    # examples=[
    #     ["vqav2.png", "What sport is this?"],
    #     ["vizwiz.jpeg", "What piece of meat have I taken out of the freezer?"],
    #     ["aqua.png", "what does bol lean nonchalantly on"],
    #     ["robotvqa.png", "How many silver spoons are there?"],
    # ],
)

demo.launch()