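"""Gradio demo: Donut 🍩 fine-tuned for UI referring expressions (RefExp).

Given a UI screenshot and a natural-language referring expression, the model
predicts the bounding box of the referenced element, draws it on the image,
and returns the raw prediction as JSON.
"""
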
import re
import gradio as gr
from PIL import Image, ImageDraw
import math
import torch
import html
from transformers import DonutProcessor, VisionEncoderDecoderModel

pretrained_repo_name = "ivelin/donut-refexp-draft"

processor = DonutProcessor.from_pretrained(pretrained_repo_name)
model = VisionEncoderDecoderModel.from_pretrained(pretrained_repo_name)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


def process_refexp(image: Image.Image, prompt: str):
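    """Ground a natural-language referring expression to a UI element.

    Returns the input image with the predicted bounding box drawn on it,
    plus the raw predicted box as a dict of fractional coordinates.
    """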

    print(f"(image, prompt): {image}, {prompt}")

    # trim prompt to 80 characters and normalize to lowercase
    prompt = prompt[:80].lower()

    # prepare encoder inputs
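    # (the Donut processor resizes, pads, and normalizes the image to the
    # resolution the encoder expects)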
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # prepare decoder inputs
    task_prompt = "<s_refexp><s_prompt>{user_input}</s_prompt><s_target_bounding_box>"
    prompt = task_prompt.replace("{user_input}", prompt)
    decoder_input_ids = processor.tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt").input_ids

    # generate answer
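    # greedy decoding (num_beams=1); <unk> is banned via bad_words_ids so
    # the decoded sequence stays parseable by token2json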
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # postprocess
    sequence = processor.batch_decode(outputs.sequences)[0]
    print(fr"predicted decoder sequence: {html.escape(sequence)}")
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, "")
    # remove first task start token
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    print(
        fr"predicted decoder sequence before token2json: {html.escape(sequence)}")
    # token2json parses the decoded token sequence into a dict; the
    # coordinate values come back as strings
    bbox = processor.token2json(sequence)
    print(f"predicted bounding box: {bbox}")

    print(f"image object: {image}")
    print(f"image size: {image.size}")
    width, height = image.size
    print(f"image width, height: {width, height}")
    print(f"processed prompt: {prompt}")

    # scale fractional coordinates to pixels; token2json yields the
    # coordinates as strings, so convert to float first, and fall back to
    # the full image when a coordinate is missing
    xmin = math.floor(width * float(bbox["xmin"])) if bbox.get("xmin") else 0
    ymin = math.floor(height * float(bbox["ymin"])) if bbox.get("ymin") else 0
    xmax = math.floor(width * float(bbox["xmax"])) if bbox.get("xmax") else width
    ymax = math.floor(height * float(bbox["ymax"])) if bbox.get("ymax") else height

    print(
        f"to image pixel values: xmin, ymin, xmax, ymax: {xmin, ymin, xmax, ymax}")

    shape = [(xmin, ymin), (xmax, ymax)]

    # draw the predicted bounding box on the input image
    draw = ImageDraw.Draw(image)
    draw.rectangle(shape, outline="green", width=5)
    return image, bbox


title = "Demo: Donut 🍩 for UI RefExp"
description = "Gradio demo for the Donut RefExp task: a `VisionEncoderDecoderModel` fine-tuned on the UIBert RefExp dataset (UI referring expressions). To use it, upload a UI screenshot, type a referring expression, and click 'Submit', or click one of the examples to load it. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.15664' target='_blank'>Donut: OCR-free Document Understanding Transformer</a> | <a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a></p>"
examples = [["example_1.jpg", "select the setting icon from top right corner"],
            ["example_2.jpg", "enter the text field next to the name"]]

demo = gr.Interface(fn=process_refexp,
                    inputs=[gr.Image(type="pil"), "text"],
                    outputs=[gr.Image(type="pil"), "json"],
                    title=title,
                    description=description,
                    article=article,
                    examples=examples,
                    cache_examples=True
                    )

demo.launch()
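
# Tip: when running locally, demo.launch(share=True) creates a temporary
# public Gradio link.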