import gradio as gr
import requests
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

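# Each infer_* helper below loads a base Pix2Struct checkpoint, preprocesses the image/question
# pair with its processor, and decodes a single generated answer. The checkpoint is reloaded on
# every call and moved to "cuda"; a GPU is assumed here, and loading each model once at module
# level would avoid the per-request re-initialization cost.
# Infographic QA: answers free-form questions about infographic images.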
def infer_infographics(image, question):
  # InfographicVQA fine-tuned base checkpoint.
  model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-infographics-vqa-base").to("cuda")
  processor = Pix2StructProcessor.from_pretrained("google/pix2struct-infographics-vqa-base")

  inputs = processor(images=image, text=question, return_tensors="pt").to("cuda")

  predictions = model.generate(**inputs)
  return processor.decode(predictions[0], skip_special_tokens=True)

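# Screen2Words checkpoint: generates a short description of a UI screenshot.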
def infer_ui(image, question):
  model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-screen2words-base").to("cuda")
  processor = Pix2StructProcessor.from_pretrained("google/pix2struct-screen2words-base")

  inputs = processor(images=image, text=question, return_tensors="pt").to("cuda")

  predictions = model.generate(**inputs)
  return processor.decode(predictions[0], skip_special_tokens=True)

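# ChartQA checkpoint: answers questions about charts and plots.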
def infer_chart(image, question):
  model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-chartqa-base").to("cuda")
  processor = Pix2StructProcessor.from_pretrained("google/pix2struct-chartqa-base")

  inputs = processor(images=image, text=question, return_tensors="pt").to("cuda")

  predictions = model.generate(**inputs)
  return processor.decode(predictions[0], skip_special_tokens=True)

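# DocVQA checkpoint: answers questions about document images.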
def infer_doc(image, question):
  model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-docvqa-base").to("cuda")
  processor = Pix2StructProcessor.from_pretrained("google/pix2struct-docvqa-base")
  inputs = processor(images=image, text=question, return_tensors="pt").to("cuda")
  predictions = model.generate(**inputs)
  return processor.decode(predictions[0], skip_special_tokens=True)

css = """
  #mkd {
    height: 500px; 
    overflow: auto; 
    border: 1px solid #ccc; 
  }
"""

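# UI: one tab per task. Each tab pairs an image and a question input with a submit button
# wired to the matching infer_* helper, plus a cached example.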
with gr.Blocks(css=css) as demo:
  gr.HTML("<h1><center>Pix2Struct 📄</center></h1>")
  gr.HTML("<h3><center>Pix2Struct is a powerful backbone for visual question answering. ⚡</center></h3>")
  gr.HTML("<h3><center>Each tab in this app demonstrates Pix2Struct models fine-tuned on document question answering, infographics question answering, question answering on user interfaces, and charts. 📄📱📊</center></h3>")
  gr.HTML("<h3><center>This app uses the base version of each model. For better performance, use the large checkpoints.</center></h3>")

  with gr.Tab(label="Visual Question Answering over Documents"):
    with gr.Row():
      with gr.Column():
        input_img = gr.Image(label="Input Document")
        question = gr.Text(label="Question")
        submit_btn = gr.Button("Submit")
      output = gr.Text(label="Answer")
    gr.Examples(
      [["docvqa_example.png", "How many items are sold?"]],
      inputs=[input_img, question],
      outputs=[output],
      fn=infer_doc,
      cache_examples=True,
      label='Click on an example below to get Document Question Answering results quickly 👇'
    )

    submit_btn.click(infer_doc, [input_img, question], [output])

  with gr.Tab(label="Visual Question Answering over Infographics"):
    with gr.Row():
      with gr.Column():
        input_img = gr.Image(label="Input Image")
        question = gr.Text(label="Question")
        submit_btn = gr.Button("Submit")
      output = gr.Text(label="Answer")
    gr.Examples(
      [["infographics_example.jpeg", "What is this infographic about?"]],
      inputs=[input_img, question],
      outputs=[output],
      fn=infer_infographics,
      cache_examples=True,
      label='Click on an example below to get Infographics QA results quickly 👇'
    )

    submit_btn.click(infer_infographics, [input_img, question], [output])
  with gr.Tab(label="Caption User Interfaces"):
    with gr.Row():
      with gr.Column():
        input_img = gr.Image(label="Input UI Image")
        question = gr.Text(label="Question")
        submit_btn = gr.Button("Submit")
      output = gr.Text(label="Caption")
    submit_btn.click(infer_ui, [input_img, question], [output])
    gr.Examples(
      [["screen2words_ui_example.png", "What is this UI about?"]],
      inputs=[input_img, question],
      outputs=[output],
      fn=infer_ui,
      cache_examples=True,
      label='Click on an example below to get UI question answering results quickly 👇'
    )

  with gr.Tab(label="Ask about Charts"):
    with gr.Row():
      with gr.Column():
        input_img = gr.Image(label="Input Chart")
        question = gr.Text(label="Question")
        submit_btn = gr.Button("Submit")
      output = gr.Text(label="Answer")

    submit_btn.click(infer_chart, [input_img, question], [output])
    gr.Examples(
      [["chartqa_example.png", "How much percent is bicycle?"]],
      inputs=[input_img, question],
      outputs=[output],
      fn=infer_chart,
      cache_examples=True,
      label='Click on an example below to get Chart question answering results quickly 👇'
    )

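# debug=True keeps the server in the foreground and surfaces tracebacks in the console.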
demo.launch(debug=True)