# pix2struct / app.py
import gradio as gr
import requests
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
import spaces
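
# Each inference function below is decorated with @spaces.GPU so ZeroGPU attaches a GPU
# only for the duration of the call; the checkpoint is loaded inside the function on every
# request. (The models could also be cached at module level, as long as the .to("cuda")
# moves stay inside the decorated function.)

# Infographics tab. Note: the checkpoint used here is the AI2D (science-diagram QA)
# fine-tune; the dedicated infographics checkpoint, google/pix2struct-infographics-vqa-base,
# may be what was intended.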
@spaces.GPU
def infer_infographics(image, question):
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-ai2d-base").to("cuda")
processor = Pix2StructProcessor.from_pretrained("google/pix2struct-ai2d-base")
inputs = processor(images=image, text=question, return_tensors="pt").to("cuda")
predictions = model.generate(**inputs)
return processor.decode(predictions[0], skip_special_tokens=True)
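
# UI tab: screen2words is a UI screen-summarization (captioning) fine-tune of Pix2Struct,
# so the generated text is a short caption describing the screenshot.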
@spaces.GPU
def infer_ui(image, question):
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-screen2words-base").to("cuda")
processor = Pix2StructProcessor.from_pretrained("google/pix2struct-screen2words-base")
inputs = processor(images=image, text=question, return_tensors="pt").to("cuda")
predictions = model.generate(**inputs)
return processor.decode(predictions[0], skip_special_tokens=True)
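
# Chart tab: the ChartQA fine-tune answers natural-language questions about charts and plots.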
@spaces.GPU
def infer_chart(image, question):
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-chartqa-base").to("cuda")
processor = Pix2StructProcessor.from_pretrained("google/pix2struct-chartqa-base")
inputs = processor(images=image, text=question, return_tensors="pt").to("cuda")
predictions = model.generate(**inputs)
return processor.decode(predictions[0], skip_special_tokens=True)
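
# Document tab: the DocVQA fine-tune answers questions about scanned or rendered documents.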
@spaces.GPU
def infer_doc(image, question):
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-docvqa-base").to("cuda")
processor = Pix2StructProcessor.from_pretrained("google/pix2struct-docvqa-base")
inputs = processor(images=image, text=question, return_tensors="pt").to("cuda")
predictions = model.generate(**inputs)
return processor.decode(predictions[0], skip_special_tokens=True)
css = """
#mkd {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""
with gr.Blocks(css=css) as demo:
gr.HTML("<h1><center>Pix2Struct πŸ“„<center><h1>")
gr.HTML("<h3><center>Pix2Struct is a powerful backbone for visual question answering. ⚑</h3>")
gr.HTML("<h3><center>Each tab in this app demonstrates Pix2Struct models fine-tuned on document question answering, infographics question answering, question answering on user interfaces, and charts. πŸ“„πŸ“±πŸ“Š<h3>")
gr.HTML("<h3><center>This app has base versions of each model. For better performance, use large checkpoints.<h3>")
with gr.Tab(label="Visual Question Answering over Documents"):
with gr.Row():
with gr.Column():
input_img = gr.Image(label="Input Document")
question = gr.Text(label="Question")
submit_btn = gr.Button("Submit")
output = gr.Text(label="Answer")
gr.Examples(
[["docvqa_example.png", "How many items are sold?"]],
inputs = [input_img, question],
outputs = [output],
fn=infer_doc,
cache_examples=True,
label='Click on any example below to get Document Question Answering results quickly 👇'
)
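# cache_examples=True runs the example above through fn ahead of time and serves the stored
# output when the example is clicked, so clicking it does not trigger a new GPU call.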
submit_btn.click(infer_doc, [input_img, question], [output])
with gr.Tab(label="Visual Question Answering over Infographics"):
with gr.Row():
with gr.Column():
input_img = gr.Image(label="Input Image")
question = gr.Text(label="Question")
submit_btn = gr.Button("Submit")
output = gr.Text(label="Answer")
gr.Examples(
[["infographics_example.jpeg", "What is this infographic about?"]],
inputs = [input_img, question],
outputs = [output],
fn=infer_infographics,
cache_examples=True,
label='Click on any example below to get Infographics QA results quickly 👇'
)
submit_btn.click(infer_infographics, [input_img, question], [output])
with gr.Tab(label="Caption User Interfaces"):
with gr.Row():
with gr.Column():
input_img = gr.Image(label="Input UI Image")
question = gr.Text(label="Question")
submit_btn = gr.Button("Submit")
output = gr.Text(label="Caption")
submit_btn.click(infer_ui, [input_img, question], [output])
gr.Examples(
[["screen2words_ui_example.png", "What is this UI about?"]],
inputs = [input_img, question],
outputs = [output],
fn=infer_ui,
cache_examples=True,
label='Click on any example below to get UI captioning results quickly 👇'
)
with gr.Tab(label="Ask about Charts"):
with gr.Row():
with gr.Column():
input_img = gr.Image(label="Input Chart")
question = gr.Text(label="Question")
submit_btn = gr.Button("Submit")
output = gr.Text(label="Answer")
submit_btn.click(infer_chart, [input_img, question], [output])
gr.Examples(
[["chartqa_example.png", "How much percent is bicycle?"]],
inputs = [input_img, question],
outputs = [output],
fn=infer_chart,
cache_examples=True,
label='Click on any example below to get Chart question answering results quickly 👇'
)
demo.launch(debug=True)