# Qwen2VL-OCR_CPU / app.py
import gradio as gr
from PIL import Image
import json
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
import re
# Load models
def load_models():
    # ColPali retriever from byaldi; loaded here but not used in the OCR-only flow below
    RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True, torch_dtype=torch.float32
    )  # float32 for CPU
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)
    return RAG, model, processor

RAG, model, processor = load_models()
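# Note (assumption, not part of the original app): if the host CPU supports bfloat16,
# loading the model with a smaller dtype roughly halves its memory footprint, e.g.:
#     model = Qwen2VLForConditionalGeneration.from_pretrained(
#         "Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True, torch_dtype=torch.bfloat16)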
# Function for OCR extraction
def extract_text(image):
    text_query = "Extract all the text in Sanskrit and English from the image."

    # Prepare message for Qwen model
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": text_query},
            ],
        }
    ]

    # Process the image
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
    ).to("cpu")  # Use CPU

    # Generate text
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=2000)

    # Strip the prompt tokens so only the newly generated text is decoded
    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    extracted_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return extracted_text
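# Illustrative use outside Gradio (hypothetical file name, sketch only):
#     page = Image.open("manuscript_page.png")
#     print(extract_text(page))
# Generation runs on CPU with up to 2000 new tokens, so a full page can be slow.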
# Function for keyword search in extracted text
def search_text(extracted_text, keyword):
    keyword_lower = keyword.lower()
    sentences = extracted_text.split('. ')
    matched_sentences = []
    for sentence in sentences:
        if keyword_lower in sentence.lower():
            # Wrap every case-insensitive occurrence of the keyword in <mark> tags for HTML highlighting
            highlighted_sentence = re.sub(
                f'({re.escape(keyword)})', r'<mark>\1</mark>', sentence, flags=re.IGNORECASE
            )
            matched_sentences.append(highlighted_sentence)
    return matched_sentences
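# Example behaviour (illustrative sentence, not taken from a real extraction):
#     search_text("Rama went to the forest. Sita followed him.", "forest")
#     -> ['Rama went to the <mark>forest</mark>']
# Note that split('. ') consumes the separator, so matched sentences lose their trailing period.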
# Gradio App
def extract_text_app(image):
    extracted_text = extract_text(image)
    # Return the text twice: once for the visible Textbox, once for the State used by the search
    return extracted_text, extracted_text

def search_text_app(extracted_text, keyword):
    search_results = search_text(extracted_text, keyword)
    search_results_str = "<br>".join(search_results) if search_results else "No matches found."
    return search_results_str
# Gradio Interface
with gr.Blocks() as iface:
    extracted_text = gr.State("")  # empty default so searching before extraction does not fail

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload an Image")
            extract_button = gr.Button("Extract Text")
            text_output = gr.Textbox(label="Extracted Text", interactive=False)

        with gr.Column():
            keyword_input = gr.Textbox(label="Enter keyword to search in extracted text", placeholder="Keyword")
            search_button = gr.Button("Search Keyword")
            search_output = gr.HTML(label="Search Results")

    # Link the buttons to their respective functions; the extract handler fills
    # both the Textbox and the State that feeds the keyword search
    extract_button.click(fn=extract_text_app, inputs=image_input, outputs=[text_output, extracted_text])
    search_button.click(fn=search_text_app, inputs=[extracted_text, keyword_input], outputs=search_output)
# Launch Gradio App
iface.launch()