import html
import math
import re

import gradio as gr
import torch
from PIL import Image, ImageDraw
from transformers import DonutProcessor, VisionEncoderDecoderModel

pretrained_repo_name = "ivelin/donut-refexp-draft"
processor = DonutProcessor.from_pretrained(pretrained_repo_name)
model = VisionEncoderDecoderModel.from_pretrained(pretrained_repo_name)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


def process_refexp(image: Image.Image, prompt: str):
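    """Ground a natural-language referring expression to a UI element.

    Draws the predicted bounding box on the input image and returns the
    annotated image together with the bounding box dict (normalized
    xmin/ymin/xmax/ymax coordinates as predicted by the model).
    """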
print(f"(image, prompt): {image}, {prompt}")
# trim prompt to 80 characters and normalize to lowercase
prompt = prompt[:80].lower()
# prepare encoder inputs
pixel_values = processor(image, return_tensors="pt").pixel_values
# prepare decoder inputs
task_prompt = "<s_refexp><s_prompt>{user_input}</s_prompt><s_target_bounding_box>"
prompt = task_prompt.replace("{user_input}", prompt)
decoder_input_ids = processor.tokenizer(
prompt, add_special_tokens=False, return_tensors="pt").input_ids
    # generate answer
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
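    # with return_dict_in_generate=True, outputs.sequences also includes the
    # decoder prompt tokens fed in above, so the decoded string starts with
    # the task prompt and must be stripped below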
    # postprocess
    sequence = processor.batch_decode(outputs.sequences)[0]
    print(f"predicted decoder sequence: {html.escape(sequence)}")
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, "")
    # remove the first task start token (<s_refexp>)
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    print(
        f"predicted decoder sequence before token2json: {html.escape(sequence)}")
    # token2json returns a dict, not a JSON string; the bounding box fields
    # are expected under the target_bounding_box key opened by the task
    # prompt, with a fallback to the top-level dict just in case
    seqjson = processor.token2json(sequence)
    bbox = seqjson.get("target_bounding_box", seqjson)
print(f"predicted bounding box: {bbox}")
print(f"image object: {image}")
print(f"image size: {image.size}")
width, height = image.size
print(f"image width, height: {width, height}")
print(f"processed prompt: {prompt}")
xmin = math.floor(width*bbox["xmin"]) if bbox.get("xmin") else 0
ymin = math.floor(height*bbox["ymin"]) if bbox.get("ymin") else 0
xmax = math.floor(width*bbox["xmax"]) if bbox.get("xmax") else 1
ymax = math.floor(height*bbox["ymax"]) if bbox.get("ymax") else 1
print(
f"to image pixel values: xmin, ymin, xmax, ymax: {xmin, ymin, xmax, ymax}")
    # create rectangle image
    shape = [(xmin, ymin), (xmax, ymax)]
    img1 = ImageDraw.Draw(image)
    img1.rectangle(shape, outline="green", width=5)
    return image, bbox


title = "Demo: Donut 🍩 for UI RefExp"
description = "Gradio demo for the Donut RefExp task: an instance of `VisionEncoderDecoderModel` fine-tuned on the UIBert RefExp dataset (UI Referring Expression). To use it, upload your image, type a referring expression, and click 'Submit', or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.15664' target='_blank'>Donut: OCR-free Document Understanding Transformer</a> | <a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a></p>"
examples = [["example_1.jpg", "select the setting icon from top right corner"],
["example_2.jpg", "enter the text field next to the name"]]
demo = gr.Interface(fn=process_refexp,
                    inputs=[gr.Image(type="pil"), "text"],
                    outputs=[gr.Image(type="pil"), "json"],
                    title=title,
                    description=description,
                    article=article,
                    examples=examples,
                    cache_examples=True)
demo.launch()
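# a minimal sketch of exercising the model without the Gradio UI; the image
# path and referring expression below are made-up placeholders:
#
#   img = Image.open("screenshot.png").convert("RGB")
#   annotated_img, bbox = process_refexp(img, "select the search button")
#   print(bbox)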