import json

import gradio as gr
import torch
from PIL import Image
from ruamel import yaml

from data.transforms import ALBEFTextTransform, testing_image_transform
from model import albef_model_for_vqa

data_dir = "./"

# Build the ALBEF VQA model from its config and load the fine-tuned checkpoint.
config = yaml.load(open("./configs/vqa.yaml", "r"), Loader=yaml.Loader)
model = albef_model_for_vqa(config)

checkpoint_url = "https://download.pytorch.org/models/multimodal/albef/finetuned_vqa_checkpoint.pt"
checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
model.load_state_dict(checkpoint)

# Transforms for the input image, the question, and the candidate answer list.
image_transform = testing_image_transform()
question_transform = ALBEFTextTransform(add_end_token=False)
answer_transform = ALBEFTextTransform(do_pre_process=False)

# Example inputs and the closed set of candidate answers the model ranks over.
vqa_data = json.load(open(data_dir + "vqa_data.json", "r"))
answer_list = json.load(open(data_dir + "answer_list.json", "r"))
examples = [[data["image"], data["question"]] for data in vqa_data]

title = "VQA with ALBEF"
description = (
    "VQA with [ALBEF](https://arxiv.org/abs/2107.07651), adapted from the "
    "[torchmultimodal example notebook]"
    "(https://github.com/facebookresearch/multimodal/blob/main/examples/albef/vqa_with_albef.ipynb)."
)
article = '''```bibtex
@article{li2021align,
  title={Align before fuse: Vision and language representation learning with momentum distillation},
  author={Li, Junnan and Selvaraju, Ramprasaath and Gotmare, Akhilesh and Joty, Shafiq and Xiong, Caiming and Hoi, Steven Chu Hong},
  journal={Advances in neural information processing systems},
  volume={34},
  pages={9694--9705},
  year={2021}
}
```'''


def infer(image, question):
    # Batch the single image and tokenize the question; attention masks flag non-pad tokens.
    images = [image]
    image_input = torch.stack([image_transform(img) for img in images], dim=0)
    question_input = question_transform([question])
    question_atts = (question_input != 0).type(torch.long)
    answer_input = answer_transform(answer_list)
    answer_atts = (answer_input != 0).type(torch.long)
    # Rank the candidate answers and keep the top-1 prediction (k=1).
    answer_ids, _ = model(
        image_input,
        question_input,
        question_atts,
        answer_input,
        answer_atts,
        k=1,
        is_train=False,
    )
    predicted_answer_id = answer_ids[0]
    predicted_answer = answer_list[predicted_answer_id]
    return predicted_answer


demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Image(label="image", type="pil", image_mode="RGB"),
        gr.Text(label="question"),
    ],
    outputs=gr.Text(label="answer"),
    examples=examples,
    title=title,
    description=description,
    article=article,
)
demo.launch()