import json

import gradio as gr
import torch
from PIL import Image
from ruamel import yaml

from data.transforms import ALBEFTextTransform, testing_image_transform
from model import albef_model_for_vqa

data_dir = "./"

# Build the ALBEF VQA model from its config and load the finetuned checkpoint.
config = yaml.load(open("./configs/vqa.yaml", "r"), Loader=yaml.Loader)
model = albef_model_for_vqa(config)

checkpoint_url = "https://download.pytorch.org/models/multimodal/albef/finetuned_vqa_checkpoint.pt"
checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
model.load_state_dict(checkpoint)

# Transforms for the input image, the question, and the candidate answers.
image_transform = testing_image_transform()
question_transform = ALBEFTextTransform(add_end_token=False)
answer_transform = ALBEFTextTransform(do_pre_process=False)
answer_list = json.load(open(data_dir + "answer_list.json", "r"))


def infer(image, question):
    # Batch of one image, stacked into a (1, C, H, W) tensor.
    images = [image]
    image_input = [image_transform(img) for img in images]
    image_input = torch.stack(image_input, dim=0)

    # Tokenize the question and the answer candidates; attention masks are 1
    # for non-padding tokens (token id 0 is padding).
    question_input = question_transform([question])
    question_atts = (question_input != 0).type(torch.long)
    answer_input = answer_transform(answer_list)
    answer_atts = (answer_input != 0).type(torch.long)

    # Rank the answer candidates and keep only the top one (k=1).
    answer_ids, _ = model(
        image_input,
        question_input,
        question_atts,
        answer_input,
        answer_atts,
        k=1,
        is_train=False,
    )
    predicted_answer_id = answer_ids[0]
    predicted_answer = answer_list[predicted_answer_id]
    return predicted_answer


demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Image(label="image", type="pil", image_mode="RGB"),
        gr.Text(label="question"),
    ],
    outputs=gr.Text(label="answer"),
    # examples=[
    #     ['vqav2.png', 'What sport is this?'],
    #     ['vizwiz.jpeg', 'What piece of meat have I taken out of the freezer?'],
    #     ['aqua.png', 'what does bol lean nonchalantly on'],
    #     ['robotvqa.png', 'How many silver spoons are there?'],
    # ],
)

demo.launch()