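# Gradio demo that runs the same image/question pair through three visual
# question answering (VQA) models (GIT, BLIP and ViLT) and shows their
# answers side by side.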
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM, BlipForQuestionAnswering, ViltForQuestionAnswering
import torch

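# Download the example images referenced in the `examples` list at the bottom of the file.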
torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
torch.hub.download_url_to_file('https://huggingface.co/datasets/nielsr/textcaps-sample/resolve/main/stop_sign.png', 'stop_sign.png')
torch.hub.download_url_to_file('https://cdn.openai.com/dall-e-2/demos/text2im/astronaut/horse/photo/0.jpg', 'astronaut.jpg')

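# GIT casts VQA as conditional text generation: the answer is decoded token by token after the question prompt.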
git_processor_base = AutoProcessor.from_pretrained("microsoft/git-base-vqav2")
git_model_base = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vqav2")

# git_processor_large = AutoProcessor.from_pretrained("microsoft/git-large-vqav2")
# git_model_large = AutoModelForCausalLM.from_pretrained("microsoft/git-large-vqav2")

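# BLIP also generates the answer as free-form text, conditioned on both the image and the question.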
blip_processor_base = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
blip_model_base = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

# blip_processor_large = AutoProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
# blip_model_large = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large")

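# ViLT instead treats VQA as classification over a fixed vocabulary of common answers.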
vilt_processor = AutoProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

device = "cuda" if torch.cuda.is_available() else "cpu"

git_model_base.to(device)
blip_model_base.to(device)
#git_model_large.to(device)
#blip_model_large.to(device)
vilt_model.to(device)

def generate_answer_git(processor, model, image, question):
    # prepare image
    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)

    # prepare question: GIT expects a CLS token followed by the question tokens
    input_ids = processor(text=question, add_special_tokens=False).input_ids
    input_ids = [processor.tokenizer.cls_token_id] + input_ids
    input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)

    generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
    generated_answer = processor.batch_decode(generated_ids, skip_special_tokens=True)

    return generated_answer


def generate_answer_blip(processor, model, image, question):
    # prepare image + question
    inputs = processor(images=image, text=question, return_tensors="pt").to(device)

    generated_ids = model.generate(**inputs, max_length=50)
    generated_answer = processor.batch_decode(generated_ids, skip_special_tokens=True)

    return generated_answer


def generate_answer_vilt(processor, model, image, question):
    # prepare image + question
    encoding = processor(images=image, text=question, return_tensors="pt").to(device)

    # forward pass: ViLT outputs one logit per candidate answer
    with torch.no_grad():
        outputs = model(**encoding)

    # pick the highest-scoring answer from the model's label map
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


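# Run all three models on the same image/question pair; each answer fills one output Textbox.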
def generate_answers(image, question):
    answer_git_base = generate_answer_git(git_processor_base, git_model_base, image, question)

    # answer_git_large = generate_answer_git(git_processor_large, git_model_large, image, question)

    answer_blip_base = generate_answer_blip(blip_processor_base, blip_model_base, image, question)

    # answer_blip_large = generate_answer_blip(blip_processor_large, blip_model_large, image, question)

    answer_vilt = generate_answer_vilt(vilt_processor, vilt_model, image, question)

    return answer_git_base, answer_blip_base, answer_vilt

   
examples = [["cats.jpg", "How many cats are there?"],
            ["stop_sign.png", "What's behind the stop sign?"],
            ["astronaut.jpg", "What's the astronaut riding on?"]]
outputs = [gr.Textbox(label="Answer generated by GIT-base"),
           gr.Textbox(label="Answer generated by BLIP-base"),
           gr.Textbox(label="Answer generated by ViLT")]

title = "Interactive demo: comparing visual question answering (VQA) models"
description = "Gradio Demo to compare GIT, BLIP and ViLT, 3 state-of-the-art vision+language models. To use it, simply upload your image and click 'submit', or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://huggingface.co/docs/transformers/main/model_doc/blip' target='_blank'>BLIP docs</a> | <a href='https://huggingface.co/docs/transformers/main/model_doc/git' target='_blank'>GIT docs</a></p>"

interface = gr.Interface(fn=generate_answers,
                         inputs=[gr.Image(type="pil"), gr.Textbox(label="Question")],
                         outputs=outputs,
                         examples=examples,
                         title=title,
                         description=description,
                         article=article)
interface.queue()
interface.launch(debug=True)