paligemma-doc / app.py
merve's picture
merve HF staff
Update app.py
b872a0b verified
raw
history blame
No virus
5.3 kB
import gradio as gr
import requests
from PIL import Image
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import spaces
@spaces.GPU
def infer_diagram(image, question):
    """Answer a question about a diagram with the AI2D fine-tuned PaliGemma.

    The checkpoint and its processor are loaded from the Hub on every call,
    moved to CUDA, and used to generate at most 100 new tokens for the given
    image/question pair.

    Args:
        image: Image forwarded to the processor as ``images``.
        question: Prompt text forwarded to the processor as ``text``.

    Returns:
        The decoded generation with its first ``len(question)`` characters
        dropped and any leading newlines stripped.
    """
    checkpoint = "google/paligemma-3b-ft-ai2d-448"
    device = "cuda"

    vlm = PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
    vlm = vlm.to(device)
    preprocessor = PaliGemmaProcessor.from_pretrained(checkpoint)

    batch = preprocessor(images=image, text=question, return_tensors="pt")
    batch = batch.to(device)
    generated = vlm.generate(**batch, max_new_tokens=100)

    # Presumably the decoded text begins with the prompt itself, which is why
    # the first len(question) characters are cut off before returning.
    decoded = preprocessor.decode(generated[0], skip_special_tokens=True)
    answer = decoded[len(question):]
    return answer.lstrip("\n")
@spaces.GPU
def infer_ocrvqa(image, question):
    """Answer a question by reading text in an image (OCR-VQA PaliGemma).

    The model and its processor are loaded from the Hub on every call, moved
    to CUDA, and used to generate at most 100 new tokens for the given
    image/question pair.

    Args:
        image: Image forwarded to the processor as ``images``.
        question: Prompt text forwarded to the processor as ``text``.

    Returns:
        The decoded generation with its first ``len(question)`` characters
        dropped and any leading newlines stripped.
    """
    # One id for both model and processor: the processor was previously
    # requested from "...-896e" (stray trailing "e"), which does not match the
    # checkpoint the model is loaded from.
    model_id = "google/paligemma-3b-ft-ocrvqa-896"

    model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).to("cuda")
    processor = PaliGemmaProcessor.from_pretrained(model_id)

    inputs = processor(images=image, text=question, return_tensors="pt").to("cuda")
    predictions = model.generate(**inputs, max_new_tokens=100)

    # Presumably the decoded text begins with the prompt itself, which is why
    # the first len(question) characters are cut off before returning.
    decoded = processor.decode(predictions[0], skip_special_tokens=True)
    return decoded[len(question):].lstrip("\n")
@spaces.GPU
def infer_infographics(image, question):
    """Answer a question about an infographic (InfoVQA PaliGemma).

    Loads the checkpoint and processor from the Hub on each call, runs
    generation on CUDA capped at 100 new tokens, and returns the decoded text
    minus its first ``len(question)`` characters and any leading newlines.

    Args:
        image: Image forwarded to the processor as ``images``.
        question: Prompt text forwarded to the processor as ``text``.

    Returns:
        The trimmed, decoded answer string.
    """
    repo = "google/paligemma-3b-ft-infovqa-896"

    net = PaliGemmaForConditionalGeneration.from_pretrained(repo).to("cuda")
    proc = PaliGemmaProcessor.from_pretrained(repo)

    encoded = proc(images=image, text=question, return_tensors="pt").to("cuda")
    token_ids = net.generate(**encoded, max_new_tokens=100)

    full_text = proc.decode(token_ids[0], skip_special_tokens=True)
    # Skip as many characters as the question has; presumably the decoded
    # sequence echoes the prompt before the answer.
    prompt_chars = len(question)
    return full_text[prompt_chars:].lstrip("\n")
@spaces.GPU
def infer_doc(image, question):
    """Answer a question about a document image (DocVQA PaliGemma).

    Every call loads the checkpoint and processor from the Hub, places the
    model and inputs on CUDA, and generates up to 100 new tokens.

    Args:
        image: Image forwarded to the processor as ``images``.
        question: Prompt text forwarded to the processor as ``text``.

    Returns:
        The decoded generation with its first ``len(question)`` characters
        dropped and any leading newlines stripped.
    """
    hub_id = "google/paligemma-3b-ft-docvqa-896"
    target = "cuda"

    doc_model = PaliGemmaForConditionalGeneration.from_pretrained(hub_id).to(target)
    doc_processor = PaliGemmaProcessor.from_pretrained(hub_id)

    model_inputs = doc_processor(
        images=image,
        text=question,
        return_tensors="pt",
    ).to(target)
    output_ids = doc_model.generate(**model_inputs, max_new_tokens=100)

    text = doc_processor.decode(output_ids[0], skip_special_tokens=True)
    # Presumably the decoded text starts with the prompt; drop that many
    # characters, then any newline separating prompt and answer.
    text = text[len(question):]
    return text.lstrip("\n")
# Custom stylesheet passed to gr.Blocks: gives the element with id "mkd" a
# fixed 500px height, scrollable overflow and a thin grey border.
css = "\n".join(
    [
        "",
        "#mkd {",
        "height: 500px;",
        "overflow: auto;",
        "border: 1px solid #ccc;",
        "}",
        "",
    ]
)
# Demo UI: a short header followed by one tab per fine-tuned PaliGemma
# checkpoint. Every tab has an image input, a question box, a Submit button,
# an answer box and one clickable example.
#
# NOTE(review): the Row/Column nesting below is reconstructed (inputs in a
# column, answer box beside it) -- confirm against the intended layout.
with gr.Blocks(css=css) as demo:
    # Header. Tags are closed properly and the emoji are written as real
    # characters; they were previously mojibake (UTF-8 bytes decoded with a
    # single-byte code page), which rendered as garbage in the page.
    gr.HTML("<h1><center>PaliGemma Fine-tuned on Documents 📄</center></h1>")
    gr.HTML("<h3><center>This Space is built for you to compare different PaliGemma models fine-tuned on document tasks. ⚡</center></h3>")
    gr.HTML("<h3><center>Each tab in this app demonstrates PaliGemma models fine-tuned on document question answering, infographics question answering, diagram understanding, and reading comprehension from images. 📄📕📊</center></h3>")
    gr.HTML("<h3><center>Models are downloaded on the go, so first inference in each tab might take time if it's not already downloaded.</center></h3>")

    # --- DocVQA -----------------------------------------------------------
    with gr.Tab(label="Visual Question Answering over Documents"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Document")
                question = gr.Text(label="Question")
                submit_btn = gr.Button(value="Submit")
            output = gr.Text(label="Answer")
        gr.Examples(
            [["assets/docvqa_example.png", "How many items are sold?"]],
            inputs=[input_img, question],
            outputs=[output],
            fn=infer_doc,
            label="Click on any Examples below to get Document Question Answering results quickly 👇",
        )
        submit_btn.click(infer_doc, [input_img, question], [output])

    # --- InfoVQA ----------------------------------------------------------
    with gr.Tab(label="Visual Question Answering over Infographics"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Image")
                question = gr.Text(label="Question")
                submit_btn = gr.Button(value="Submit")
            output = gr.Text(label="Answer")
        gr.Examples(
            [["assets/infographics_example (1).jpeg", "What is this infographic about?"]],
            inputs=[input_img, question],
            outputs=[output],
            fn=infer_infographics,
            label="Click on any Examples below to get Infographics QA results quickly 👇",
        )
        submit_btn.click(infer_infographics, [input_img, question], [output])

    # --- OCR-VQA ----------------------------------------------------------
    with gr.Tab(label="Reading from Images"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Document")
                question = gr.Text(label="Question")
                submit_btn = gr.Button(value="Submit")
            output = gr.Text(label="Infer")
        submit_btn.click(infer_ocrvqa, [input_img, question], [output])
        gr.Examples(
            [["assets/ocrvqa.jpg", "Who is the author of this book?"]],
            inputs=[input_img, question],
            outputs=[output],
            # Must match this tab's Submit handler; it used to point at
            # infer_doc, so example results came from the DocVQA model.
            fn=infer_ocrvqa,
            label="Click on any Examples below to get image reading comprehension results quickly 👇",
        )

    # --- AI2D diagrams ----------------------------------------------------
    with gr.Tab(label="Diagram Understanding"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Diagram")
                question = gr.Text(label="Question")
                submit_btn = gr.Button(value="Submit")
            output = gr.Text(label="Infer")
        submit_btn.click(infer_diagram, [input_img, question], [output])
        gr.Examples(
            [["assets/diagram.png", "What is the diagram showing?"]],
            inputs=[input_img, question],
            outputs=[output],
            # Must match this tab's Submit handler; it used to point at
            # infer_doc, so example results came from the DocVQA model.
            fn=infer_diagram,
            label="Click on any Examples below to get diagram understanding results quickly 👇",
        )

demo.launch(debug=True)