import gradio as gr
import re
import requests
import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig

def load_image_from_URL(url):
    """Fetch an image over HTTP and return it as an RGB PIL image, or None on failure."""
    res = requests.get(url, stream=True)

    if res.status_code == 200:
        img = Image.open(res.raw)

        # Donut expects 3-channel input, so flatten alpha or palette modes.
        if img.mode != "RGB":
            img = img.convert("RGB")

        return img

    return None

class OCRVQAModel(torch.nn.Module):
    def __init__(self, config):
        super().__init__()

        self.model_name_or_path = config['donut']
        self.processor_name_or_path = config['processor']
        self.config_name_or_path = config['config']

        # Start from the DocVQA-finetuned config, then shrink the input resolution
        # and cap the decoder length for faster inference.
        self.donut_config = VisionEncoderDecoderConfig.from_pretrained(self.config_name_or_path)
        self.donut_config.encoder.image_size = [800, 600]
        self.donut_config.decoder.max_length = 64

        self.processor = DonutProcessor.from_pretrained(self.processor_name_or_path)
        self.donut = VisionEncoderDecoderModel.from_pretrained(self.model_name_or_path, config=self.donut_config)

        self.added_tokens = set()
        self.setup()

    def add_tokens(self, list_of_tokens):
        """Register extra tokens and grow the decoder's embedding matrix to match."""
        self.added_tokens.update(list_of_tokens)
        newly_added_num = self.processor.tokenizer.add_tokens(list_of_tokens)

        if newly_added_num > 0:
            self.donut.decoder.resize_token_embeddings(len(self.processor.tokenizer))

    def setup(self):
        self.add_tokens(["<yes/>", "<no/>"])
        # The encoder config stores (height, width); the feature extractor expects (width, height).
        self.processor.feature_extractor.size = self.donut_config.encoder.image_size[::-1]
        self.processor.feature_extractor.do_align_long_axis = False

    def inference(self, image, prompt, device):
        self.donut.eval()

        with torch.no_grad():
            pixel_values = self.processor(image, return_tensors="pt").pixel_values.to(device)

            # Donut is prompted with the DocVQA task token and the question;
            # generation continues from "<s_answer>".
            question = f'<s_docvqa><s_question>{prompt}</s_question><s_answer>'
            decoder_input_ids = self.processor.tokenizer(
                question,
                add_special_tokens=False,
                return_tensors="pt"
            )["input_ids"].to(device)

            outputs = self.donut.generate(
                pixel_values,
                decoder_input_ids=decoder_input_ids,
                max_length=self.donut.decoder.config.max_position_embeddings,
                early_stopping=True,
                pad_token_id=self.processor.tokenizer.pad_token_id,
                eos_token_id=self.processor.tokenizer.eos_token_id,
                use_cache=True,
                num_beams=1,
                bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
                return_dict_in_generate=True
            )

            # Drop EOS/pad tokens and the leading task token before parsing the
            # XML-like output into a dict such as {"question": ..., "answer": ...}.
            sequence = self.processor.batch_decode(outputs.sequences)[0]
            sequence = sequence.replace(self.processor.tokenizer.eos_token, "")
            sequence = sequence.replace(self.processor.tokenizer.pad_token, "")
            sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()

            return self.processor.token2json(sequence)

model = OCRVQAModel({
    "donut": "ndtran/donut_ocr-vqa-200k",
    "processor": "ndtran/donut_ocr-vqa-200k",
    "config": "naver-clova-ix/donut-base-finetuned-docvqa"
})

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

def get_answer(image, url, question) -> str:
    # Prefer the URL when one is provided; otherwise use the uploaded image.
    if url and url.startswith(('http://', 'https://')):
        image = load_image_from_URL(url)

    # Guard against a missing upload or a URL that failed to load.
    if image is None:
        return "I don't know :<"

    result = model.inference(image, question, device)
    return result.get('answer', "I don't know :<")

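# A quick manual smoke test of the wiring above, left commented out so it does not
# run on startup. The URL and question below are placeholders, not known-good samples.
#
# if __name__ == "__main__":
#     sample_url = "https://example.com/book-cover.jpg"  # hypothetical image URL
#     print(get_answer(None, sample_url, "Who is the author of this book?"))
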
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(
            """
            ## Donut-OCR-VQA
            - This demo uses a Donut model fine-tuned on the OCR-VQA-200k dataset to answer questions about images.

            ## IO description
            - Input: an image or an image URL, ideally a book cover, plus a question about information visible in the image.
            - Output: an answer to the question.
            """
        )

    with gr.Row():
        with gr.Column():
            image = gr.Image(shape=(224, 224), type="pil", label="Pick an image")
            image_url = gr.Textbox(lines=1, label="Or use this option!", placeholder="Enter the image URL here")
            question = gr.Textbox(lines=5, label="Question")

            ask = gr.Button("Get the answer")

        with gr.Column():
            answer = gr.Label(label="Answer")

    ask.click(get_answer, inputs=[image, image_url, question], outputs=[answer])

demo.launch()
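# Note: on CPU-only hosts inference can take a while per request; enabling Gradio's
# request queue (demo.queue().launch()) is an option if requests start timing out.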