# Qwen2VL-OCR_CPU / app.py
import re

import gradio as gr
import torch
from PIL import Image
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
# Load models (ColPali retriever via byaldi + Qwen2-VL for OCR)
def load_models():
    RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True, torch_dtype=torch.float32
    )  # float32 for CPU inference
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)
    return RAG, model, processor

RAG, model, processor = load_models()
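# NOTE: The ColPali retriever (RAG) loaded above is never used by the UI below;
# only Qwen2-VL runs. A minimal sketch of how it could be wired in, assuming
# byaldi's documented index/search API and a hypothetical "docs/" folder:
#
#   RAG.index(input_path="docs/", index_name="ocr_docs", overwrite=True)
#   results = RAG.search("Sanskrit verse about dharma", k=3)  # scored page hits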
# Run Qwen2-VL OCR on a single image
def extract_text_from_image(image):
    text_query = "Extract all the text in Sanskrit and English from the image."

    # Prepare the chat-style message for the Qwen2-VL processor
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": text_query},
            ],
        }
    ]

    # Build model inputs: chat template for the text, vision preprocessing for the image
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
    ).to("cpu")  # keep everything on CPU

    # Generate, then strip the prompt tokens so only the new text is decoded
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=2000)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    extracted_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return extracted_text
# Highlight keyword matches in the extracted text
def search_keyword_in_text(extracted_text, keyword):
    keyword_lower = keyword.lower()
    # Naive sentence split on ". "; good enough for short OCR output
    sentences = extracted_text.split('. ')
    matched_sentences = []
    for sentence in sentences:
        if keyword_lower in sentence.lower():
            # Wrap each case-insensitive match in <mark> for HTML display
            highlighted_sentence = re.sub(
                f'({re.escape(keyword)})', r'<mark>\1</mark>', sentence, flags=re.IGNORECASE
            )
            matched_sentences.append(highlighted_sentence)
    return matched_sentences if matched_sentences else ["No matches found."]
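# Example (hypothetical strings, for illustration):
#   search_keyword_in_text("Rama went to the forest. Sita followed.", "forest")
#   -> ['Rama went to the <mark>forest</mark>']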
# Gradio callbacks
def app_extract_text(image):
    return extract_text_from_image(image)

def app_search_keyword(extracted_text, keyword):
    search_results = search_keyword_in_text(extracted_text, keyword)
    return "<br>".join(search_results)
title_html = """
<h1><span class="gradient-text" id="text">IIT Roorkee - OCR and Document Search Web Application Prototype (ColPali via the Byaldi library + Hugging Face Transformers for Qwen2-VL)</span></h1>
"""
# Gradio Interface
with gr.Blocks() as iface:
    gr.HTML(title_html)
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload an Image")
            extract_button = gr.Button("Extract Text")
            extracted_text_output = gr.Textbox(label="Extracted Text")
            extract_button.click(app_extract_text, inputs=image_input, outputs=extracted_text_output)
        with gr.Column():
            keyword_input = gr.Textbox(label="Enter keyword to search in extracted text", placeholder="Keyword")
            search_button = gr.Button("Search Keyword")
            search_results_output = gr.HTML(label="Search Results")
            search_button.click(app_search_keyword, inputs=[extracted_text_output, keyword_input], outputs=search_results_output)
# Launch Gradio App
iface.launch()
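# iface.launch(share=True) would additionally expose a temporary public URL
# (a standard Gradio launch option), useful when running outside a hosted Space.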