# Qwen2VL-OCR_CPU / app.py
import gradio as gr
from PIL import Image
import json
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
import re
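
# CPU-only Gradio demo: Qwen2-VL-2B transcribes the text in an uploaded image, and the
# transcription is then searched for a user-supplied keyword with <mark> highlighting.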

# Load models
def load_models():
    RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        trust_remote_code=True,
        torch_dtype=torch.float32,  # float32 for CPU
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)
    return RAG, model, processor


RAG, model, processor = load_models()
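
# Note: the ColPali retriever (RAG) loaded via byaldi is not used anywhere below;
# OCR is handled entirely by the Qwen2-VL model and its processor.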

# Function for OCR and search
def ocr_and_search(image, keyword):
    text_query = "Extract all the text in Sanskrit and English from the image."

    # Prepare message for Qwen model
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": text_query},
            ],
        }
    ]

    # Build the chat-formatted prompt, gather the vision inputs from `messages`
    # via process_vision_info, and tokenize/preprocess everything for the model
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to("cpu")  # Use CPU

    # Generate text
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=2000)

    # Drop the prompt tokens so only the newly generated text is decoded
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    extracted_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    # Perform keyword search with highlighting
    keyword_lower = keyword.lower()
    sentences = extracted_text.split('. ')
    matched_sentences = []
    for sentence in sentences:
        if keyword_lower in sentence.lower():
            highlighted_sentence = re.sub(
                f'({re.escape(keyword)})',
                r'<mark>\1</mark>',
                sentence,
                flags=re.IGNORECASE,
            )
            matched_sentences.append(highlighted_sentence)

    return extracted_text, matched_sentences
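
# The app() wrapper below formats ocr_and_search's outputs for Gradio: plain text for the
# Textbox output, and <mark>-highlighted sentences joined with <br> for the HTML output.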
# Gradio App
def app(image, keyword):
    extracted_text, search_results = ocr_and_search(image, keyword)
    search_results_str = "<br>".join(search_results) if search_results else "No matches found."
    return extracted_text, search_results_str

# Gradio Interface
iface = gr.Interface(
    fn=app,
    inputs=[
        gr.Image(type="pil", label="Upload an Image"),
        gr.Textbox(label="Enter keyword to search in extracted text", placeholder="Keyword"),
    ],
    outputs=[
        gr.Textbox(label="Extracted Text"),
        gr.HTML(label="Search Results"),
    ],
    title="OCR and Keyword Search in Images",
)
# Launch Gradio App
iface.launch()
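
# A minimal sketch of querying the running app programmatically, assuming a recent
# gradio_client and the default local URL; "page.png" and "dharma" are placeholder inputs:
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860")
#   extracted_text, matches_html = client.predict(
#       handle_file("page.png"),  # image input
#       "dharma",                 # keyword to highlight
#       api_name="/predict",      # gr.Interface's default endpoint name
#   )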