File size: 2,835 Bytes
e67d9c6
3381383
e67d9c6
 
61ac553
e67d9c6
 
61ac553
 
e67d9c6
3381383
 
e67d9c6
 
3381383
 
e67d9c6
 
 
3381383
e67d9c6
 
6f130b0
 
 
 
 
e67d9c6
6f130b0
e67d9c6
 
 
f6ef903
e67d9c6
edac551
61ac553
edac551
e67d9c6
 
 
 
 
 
 
6f130b0
e67d9c6
 
 
 
cbfb835
e67d9c6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import gradio as gr
from transformers import ViltProcessor, ViltForImagesAndTextClassification
import torch

# NLVR2 example images: each (url, local file name) pair is downloaded once
# at import time so the demo's example rows can refer to the local files.
_EXAMPLE_IMAGE_SOURCES = (
    ('https://lil.nlp.cornell.edu/nlvr/exs/ex0_0.jpg', 'image1.jpg'),
    ('https://lil.nlp.cornell.edu/nlvr/exs/ex0_1.jpg', 'image2.jpg'),
    ('https://lil.nlp.cornell.edu/nlvr/exs/acorns_1.jpg', 'image3.jpg'),
    ('https://lil.nlp.cornell.edu/nlvr/exs/acorns_6.jpg', 'image4.jpg'),
)
for _url, _local_name in _EXAMPLE_IMAGE_SOURCES:
    torch.hub.download_url_to_file(_url, _local_name)

# The processor and the classifier are loaded from the same checkpoint id.
_CHECKPOINT = "dandelin/vilt-b32-finetuned-nlvr2"
processor = ViltProcessor.from_pretrained(_CHECKPOINT)
model = ViltForImagesAndTextClassification.from_pretrained(_CHECKPOINT)

def predict(image1, image2, text):
    """Score a sentence against a pair of images with the loaded ViLT model.

    Args:
        image1: First image of the pair.
        image2: Second image of the pair.
        text: Sentence to evaluate against the two images.

    Returns:
        dict: One entry per label in ``model.config.label2id``, mapping the
        label name to its softmax probability as a Python float.
    """
    # Tokenize the sentence and preprocess both images in one processor call.
    encoding = processor([image1, image2], text, return_tensors="pt")

    # Inference only, so gradient tracking is disabled.  ``unsqueeze(0)``
    # prepends one dimension to the processed image tensor before it is
    # handed to the model.
    with torch.no_grad():
        outputs = model(
            input_ids=encoding.input_ids,
            pixel_values=encoding.pixel_values.unsqueeze(0),
        )

    # Normalize the logits over dimension 1 to obtain per-label probabilities.
    probs = torch.nn.functional.softmax(outputs.logits, dim=1)

    # ``label_id`` (not ``id``) keeps the builtin ``id`` usable in this scope.
    return {
        label: probs[:, label_id].item()
        for label, label_id in model.config.label2id.items()
    }
   
# Input widgets: two PIL image inputs followed by a two-line sentence box.
images = [gr.inputs.Image(type="pil") for _ in range(2)]
text = gr.inputs.Textbox(lines=2, label="Sentence")

# Output widget: label component configured with two top classes.
label = gr.outputs.Label(num_top_classes=2, type="confidences")

# Text shown on the demo page.
title = "Interactive demo: natural language visual reasoning with ViLT"
description = "Gradio Demo for ViLT (Vision and Language Transformer), fine-tuned on NLVR2. To use it, simply upload a pair of images and type a sentence and click 'submit', or click one of the examples to load them. The model will predict whether the sentence is true or false, based on the 2 images. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2102.03334' target='_blank'>ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision</a> | <a href='https://github.com/dandelin/ViLT' target='_blank'>Github Repo</a></p>"

# Example rows: two image file names followed by the sentence for that pair.
example_sentence_1 = "The left image contains twice the number of dogs as the right image, and at least two dogs in total are standing."
example_sentence_2 = "One image shows exactly two brown acorns in back-to-back caps on green foliage."
examples = [
    ["image1.jpg", "image2.jpg", example_sentence_1],
    ["image3.jpg", "image4.jpg", example_sentence_2],
]

# Wire the prediction function to the widgets and start the app.
interface = gr.Interface(
    fn=predict,
    inputs=[*images, text],
    outputs=label,
    examples=examples,
    title=title,
    description=description,
    article=article,
    theme="default",
    enable_queue=True,
)
interface.launch(debug=True)