import gradio as gr
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
import re

# Load models
def load_models():
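    # ColPali retriever via Byaldi; kept available for document indexing/search,
    # though only the Qwen2-VL OCR path is wired into the UI below.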
    RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")
    model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True, torch_dtype=torch.float32)  # float32 for CPU
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)
    return RAG, model, processor

RAG, model, processor = load_models()
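
# Note: the ColPali retriever is not wired into the interface yet. A minimal
# sketch of how it could back multi-page document search, assuming Byaldi's
# index()/search() API (the "docs/" path and index name are illustrative):
#
#   RAG.index(input_path="docs/", index_name="sanskrit_docs",
#             store_collection_with_index=False, overwrite=True)
#   results = RAG.search("dharma", k=3)  # each result carries doc_id, page_num, score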

# Function for OCR
def extract_text_from_image(image):
    text_query = "Extract all the text in Sanskrit and English from the image."
    
    # Prepare message for Qwen model
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": text_query}
            ]
        }
    ]
    
    # Process the image
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
    ).to("cpu")  # Use CPU
    
    # Generate text
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=2000)
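        # Trim the prompt tokens so only the newly generated text is decoded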
        generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
        extracted_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    
    return extracted_text

# Function for keyword search
def search_keyword_in_text(extracted_text, keyword):
    keyword_lower = keyword.lower()
    # Split on English and Sanskrit (danda) sentence terminators
    sentences = re.split(r'(?<=[.!?।])\s+', extracted_text)
    matched_sentences = []
    
    for sentence in sentences:
        if keyword_lower in sentence.lower():
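            # Wrap each case-insensitive match in <mark> tags for HTML highlighting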
            highlighted_sentence = re.sub(f'({re.escape(keyword)})', r'<mark>\1</mark>', sentence, flags=re.IGNORECASE)
            matched_sentences.append(highlighted_sentence)
    
    return matched_sentences if matched_sentences else ["No matches found."]

# Gradio wrapper callbacks
def app_extract_text(image):
    extracted_text = extract_text_from_image(image)
    return extracted_text

def app_search_keyword(extracted_text, keyword):
    search_results = search_keyword_in_text(extracted_text, keyword)
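    # Join with <br> so the gr.HTML component renders one match per line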
    search_results_str = "<br>".join(search_results)
    return search_results_str


title_html = """
<h1><span class="gradient-text" id="text">IIT Roorkee - OCR and Document Search Web Application Prototype (ColPali via the Byaldi library + Hugging Face Transformers for Qwen2-VL)</span></h1>
"""

# Gradio Interface
with gr.Blocks() as iface:

    gr.HTML(title_html)

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload an Image")
            extract_button = gr.Button("Extract Text")
            extracted_text_output = gr.Textbox(label="Extracted Text")
            
            extract_button.click(app_extract_text, inputs=image_input, outputs=extracted_text_output)

        with gr.Column():
            keyword_input = gr.Textbox(label="Enter keyword to search in extracted text", placeholder="Keyword")
            search_button = gr.Button("Search Keyword")
            search_results_output = gr.HTML(label="Search Results")

            search_button.click(app_search_keyword, inputs=[extracted_text_output, keyword_input], outputs=search_results_output)

# Launch Gradio App
iface.launch()
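# Pass share=True to iface.launch() if a temporary public link is needed.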