donut-docvqa / app.py
nielsr's picture
nielsr HF staff
Create new file
994b44c
raw
history blame
2.36 kB
import re
from PIL import Image
import gradio as gr
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
# Hub checkpoint id shared by the processor and the model, so both are
# always loaded from the same repository.
_CHECKPOINT = "naver-clova-ix/donut-base-finetuned-docvqa"

processor = DonutProcessor.from_pretrained(_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_CHECKPOINT)

# Run on the GPU when one is visible to torch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
def process_document(image, question="When is the coffee break?"):
    """Answer a question about a document image with the Donut DocVQA model.

    Args:
        image: Document image in a form accepted by the module-level
            ``processor`` (the Gradio interface is configured to pass a
            PIL image).
        question: Question to ask about the document. Defaults to the
            question that used to be hard-coded in the body, so existing
            single-argument callers keep their behaviour.

    Returns:
        Whatever ``processor.token2json`` produces for the generated
        sequence (the interface renders it as JSON).
    """
    # prepare encoder inputs
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # prepare decoder inputs: the question is spliced into the DocVQA task
    # prompt, and generation continues from the opening <s_answer> tag
    task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
    prompt = task_prompt.replace("{user_input}", question)
    decoder_input_ids = processor.tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    # generate answer
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        # never emit the unknown token
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # postprocess: drop EOS/PAD markers before parsing the tagged sequence
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, ""
    )
    # remove first task start token
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    return processor.token2json(sequence)
# Load the bundled example image and write it back out under the name used
# in `examples` below.
# NOTE(review): source and destination resolve to the same path, so this
# only re-encodes the file in place -- confirm whether it is still needed.
image = Image.open("./example_1.png")
image.save("example_1.png")

demo = gr.Interface(
    fn=process_document,
    # `gr.inputs.*` is the deprecated legacy namespace; components are
    # exposed at the top level of the gradio package.
    inputs=gr.Image(type="pil"),
    outputs="json",
    title="Interactive demo: Donut 🍩 for DocVQA",
    description="""This model is fine-tuned on the DocVQA dataset. <br>
Documentation: https://huggingface.co/docs/transformers/main/en/model_doc/donut
Notebooks: https://github.com/NielsRogge/Transformers-Tutorials/tree/master/Donut
More details are available at:
- Paper: https://arxiv.org/abs/2111.15664
- Original repository: https://github.com/clovaai/donut""",
    examples=[["example_1.png"]],
    cache_examples=False,
)
demo.launch()