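# Gradio demo: document visual question answering with a Donut model fine-tuned on OCR-VQA-200k.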
import gradio as gr
import torch
import requests
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig

def load_image_from_URL(url):
    """Download an image over HTTP and return it as an RGB PIL image, or None on failure."""
    # Use a single streamed request instead of fetching the URL twice.
    res = requests.get(url, stream=True)
    if res.status_code == 200:
        img = Image.open(res.raw)
        if img.mode != "RGB":
            img = img.convert("RGB")
        return img
    return None

class OCRVQAModel(torch.nn.Module):
    """Donut vision encoder-decoder wrapped for OCR-VQA style question answering."""

    def __init__(self, config):
        super().__init__()
        self.model_name_or_path = config['donut']
        self.processor_name_or_path = config['processor']
        self.config_name_or_path = config['config']
        # Start from the DocVQA-finetuned config, then shrink the input resolution
        # and the maximum decoding length for this demo.
        self.donut_config = VisionEncoderDecoderConfig.from_pretrained(self.config_name_or_path)
        self.donut_config.encoder.image_size = [800, 600]
        self.donut_config.decoder.max_length = 64
        self.processor = DonutProcessor.from_pretrained(self.processor_name_or_path)
        self.donut = VisionEncoderDecoderModel.from_pretrained(self.model_name_or_path, config=self.donut_config)
        self.added_tokens = set()
        self.setup()

    def add_tokens(self, list_of_tokens):
        # Register extra tokens and grow the decoder embeddings to match the tokenizer.
        self.added_tokens.update(list_of_tokens)
        newly_added_num = self.processor.tokenizer.add_tokens(list_of_tokens)
        if newly_added_num > 0:
            self.donut.decoder.resize_token_embeddings(len(self.processor.tokenizer))

    def setup(self):
        self.add_tokens(["<yes/>", "<no/>"])
        # The feature extractor expects (width, height), the reverse of the encoder's image_size.
        self.processor.feature_extractor.size = self.donut_config.encoder.image_size[::-1]
        self.processor.feature_extractor.do_align_long_axis = False
    def inference(self, image, prompt, device):
        self.donut.eval()
        try:
            with torch.no_grad():
                # Encode the image and build the DocVQA-style decoder prompt.
                image_ids = self.processor(image, return_tensors="pt").pixel_values.to(device)
                question = f'<s_docvqa><s_question>{prompt}</s_question><s_answer>'
                embedded_question = self.processor.tokenizer(
                    question,
                    add_special_tokens=False,
                    return_tensors="pt"
                )["input_ids"].to(device)
                outputs = self.donut.generate(
                    image_ids,
                    decoder_input_ids=embedded_question,
                    max_length=self.donut.decoder.config.max_position_embeddings,
                    early_stopping=True,
                    pad_token_id=self.processor.tokenizer.pad_token_id,
                    eos_token_id=self.processor.tokenizer.eos_token_id,
                    use_cache=True,
                    num_beams=1,
                    bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
                    return_dict_in_generate=True
                )
                # token2json converts "<s_question>...<s_answer>..." back into a dict.
                return self.processor.token2json(self.processor.batch_decode(outputs.sequences)[0])
        except Exception:
            return {
                'question': prompt,
                'answer': 'Some error occurred during inference time.'
            }
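
# Instantiate once at import time: the weights and processor come from the OCR-VQA-200k
# fine-tune, while the architecture config is borrowed from the DocVQA-finetuned base model.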
model = OCRVQAModel({
"donut": "ndtran/donut_ocr-vqa-200k",
"processor": "ndtran/donut_ocr-vqa-200k",
"config": "naver-clova-ix/donut-base-finetuned-docvqa"
})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
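
# Hypothetical direct-usage sketch (not part of the original app); the URL and
# question below are placeholders:
#   img = load_image_from_URL("https://example.com/book-cover.jpg")
#   print(model.inference(img, "Who is the author of this book?", device))
#   # -> e.g. {'question': 'Who is the author of this book?', 'answer': '...'}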

def get_answer(image, url, question) -> str:
    # A non-empty URL takes precedence over the uploaded image.
    if url is not None and url.startswith('http'):
        remote_image = load_image_from_URL(url)
        if remote_image is not None:
            result = model.inference(remote_image, question, device)
            return result.get('answer', "I don't know :<")
    result = model.inference(image, question, device)
    return result.get('answer', "I don't know :<")

with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(
            """
            ## Donut-OCR-VQA
            - This demo uses a Donut model fine-tuned on the OCR-VQA-200k dataset to answer questions about images.
            ## I/O description
            - Input: an image or an image URL, ideally a book cover, plus a question about information shown on the image.
            - Output: an answer to the question.
            """
        )
    with gr.Row():
        with gr.Column():
            image = gr.Image(shape=(224, 224), type="pil", label="Pick an image")
            image_url = gr.Textbox(lines=1, label="Or use this option!", placeholder="Enter the image URL here")
            question = gr.Textbox(lines=5, label="Question")
            ask = gr.Button("Get the answer")
        with gr.Column():
            answer = gr.Label(label="Answer")
    ask.click(get_answer, inputs=[image, image_url, question], outputs=[answer])

demo.launch()
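
# When running locally, one could pass share=True to demo.launch() for a temporary public link.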