# ui-refexp / app.py
import html
import json
import math
import re

import gradio as gr
import torch
from PIL import Image, ImageDraw
from transformers import DonutProcessor, VisionEncoderDecoderModel

pretrained_repo_name = "ivelin/donut-refexp-draft"
processor = DonutProcessor.from_pretrained(pretrained_repo_name)
model = VisionEncoderDecoderModel.from_pretrained(pretrained_repo_name)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
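
# Donut is an OCR-free image-to-text transformer: the vision encoder embeds
# the UI screenshot and the text decoder generates an XML-like token sequence
# from which a bounding box is parsed in process_refexp below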


def process_refexp(image: Image.Image, prompt: str):
    print(f"(image, prompt): {image}, {prompt}")

    # trim prompt to 80 characters and normalize to lowercase
    prompt = prompt[:80].lower()

    # prepare encoder inputs
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # prepare decoder inputs: wrap the user prompt in the task start tokens
    task_prompt = "<s_refexp><s_prompt>{user_input}</s_prompt><s_target_bounding_box>"
    prompt = task_prompt.replace("{user_input}", prompt)
    decoder_input_ids = processor.tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt").input_ids
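    # generation continues from this primed schema, so the decoder only has
    # to produce the bounding box portion of the output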
    # generate answer
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
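
    # note: with num_beams=1 this is greedy decoding; early_stopping only
    # affects beam search, so it is inert here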

    # postprocess: strip special tokens, then parse the sequence into a dict
    sequence = processor.batch_decode(outputs.sequences)[0]
    print(f"predicted decoder sequence: {html.escape(sequence)}")
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, "")
    # remove first task start token
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    print(
        f"predicted decoder sequence before token2json: {html.escape(sequence)}")
    # token2json parses the XML-like sequence into a dict whose leaf values
    # are strings, e.g. {"xmin": "0.1", ...}; the coordinates are coerced to
    # float where they are consumed below
    bbox = processor.token2json(sequence)
    print(f"predicted bounding box: {bbox}")
print(f"image object: {image}")
print(f"image size: {image.size}")
width, height = image.size
print(f"image width, height: {width, height}")
print(f"processed prompt: {prompt}")
xmin = math.floor(width*bbox["xmin"]) if bbox.get("xmin") else 0
ymin = math.floor(height*bbox["ymin"]) if bbox.get("ymin") else 0
xmax = math.floor(width*bbox["xmax"]) if bbox.get("xmax") else 1
ymax = math.floor(height*bbox["ymax"]) if bbox.get("ymax") else 1
print(
f"to image pixel values: xmin, ymin, xmax, ymax: {xmin, ymin, xmax, ymax}")

    # draw the predicted bounding box on the input image
    shape = [(xmin, ymin), (xmax, ymax)]
    draw = ImageDraw.Draw(image)
    draw.rectangle(shape, outline="green", width=5)
    return image, bbox
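

# Minimal usage sketch outside Gradio (assumption: "screenshot.png" is a
# hypothetical local file, not shipped with this Space):
#
#   img = Image.open("screenshot.png").convert("RGB")
#   annotated, box = process_refexp(img, "select the search icon")
#   annotated.save("annotated.png")
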
title = "Demo: Donut 🍩 for UI RefExp"
description = "Gradio demo for the Donut RefExp task: an instance of `VisionEncoderDecoderModel` fine-tuned on the UIBert RefExp dataset (UI Referring Expression). To use it, upload a UI screenshot, type a referring expression for a UI component, and click 'Submit', or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.15664' target='_blank'>Donut: OCR-free Document Understanding Transformer</a> | <a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a></p>"
examples = [["example_1.jpg", "select the setting icon from top right corner"],
            ["example_2.jpg", "enter the text field next to the name"]]
demo = gr.Interface(fn=process_refexp,
                    inputs=[gr.Image(type="pil"), "text"],
                    outputs=[gr.Image(type="pil"), "json"],
                    title=title,
                    description=description,
                    article=article,
                    examples=examples,
                    cache_examples=True
                    )
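# on Hugging Face Spaces launch() needs no arguments; when running elsewhere,
# demo.launch(share=True) would also create a temporary public link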
demo.launch()