import base64
import re
from io import BytesIO

import gradio as gr
import torch
from byaldi import RAGMultiModalModel
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

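# Load the ColPali visual retriever (via byaldi) and the Qwen2-VL-2B model used for OCR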
RAG = RAGMultiModalModel.from_pretrained("vidore/colpali", verbose=10)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype=torch.float16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

def create_rag_index(image_path):
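    """Build (or rebuild) a ColPali index over the uploaded image so it can be searched."""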
    RAG.index(
        input_path=image_path,
        index_name="image_index",
        store_collection_with_index=True,
        overwrite=True,
    )

def extract_relevant_text(qwen_output):
    # Extract the main content from the Qwen2-VL output (assuming it's a list of strings)
    qwen_text = qwen_output[0]

    # Split the output into individual lines so each can be checked on its own
    lines = qwen_text.split('\n')

    # Initialize a list to hold relevant text lines
    relevant_text = []

    # Loop through each line to identify relevant text
    for line in lines:
        # Keep lines that start with a letter or digit, skipping empty or punctuation-only lines
        if re.match(r'[A-Za-z0-9]', line):
            relevant_text.append(line.strip())

    # Join the relevant text into a single output (you can customize the format)
    return "\n".join(relevant_text)


# Full pipeline: optionally retrieve the most relevant page with ColPali, then OCR it with Qwen2-VL
def ocr_image(image_path, text_query):
    if text_query:
        # Index the uploaded image and retrieve the best-matching page as a base64 string
        create_rag_index(image_path)
        results = RAG.search(text_query, k=1, return_base64_results=True)

        image_data = base64.b64decode(results[0].base64)
        image = Image.open(BytesIO(image_data))
    else:
        # No query given: run OCR directly on the uploaded image
        image = Image.open(image_path)
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image,
                },
                {
                    "type": "text",
                    "text": "explain all text find in the image."
                }
            ]
        }
    ]

    text_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)

    inputs = processor(
        text=[text_prompt],
        images=[image],
        padding=True,
        return_tensors="pt"
    )
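    # Move the processed inputs onto the same device as the model (GPU when available)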
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
    inputs = inputs.to(device)

    output_ids = model.generate(**inputs, max_new_tokens=1024)

    # Strip the prompt tokens so only the newly generated tokens are decoded
    generated_ids = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, output_ids)
    ]

    output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)

    # Extract relevant text from the Qwen2-VL output
    relevant_text = extract_relevant_text(output_text)

    return relevant_text


def highlight_text(text, query):
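    """Wrap each occurrence of every query word in a yellow <span> for HTML display."""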
    highlighted_text = text
    for word in query.split():
        pattern = re.compile(re.escape(word), re.IGNORECASE)
        highlighted_text = pattern.sub(lambda m: f'<span style="background-color: yellow;">{m.group()}</span>', highlighted_text)
    return highlighted_text

def ocr_and_search(image, keyword):
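    """Run OCR on the image and, when a keyword is given, highlight its matches in the result."""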
    extracted_text = ocr_image(image, keyword)
    if keyword == '':
        return extracted_text, 'Please enter a keyword'
    highlighted_text = highlight_text(extracted_text, keyword)
    return extracted_text, highlighted_text

# Create Gradio Interface
interface = gr.Interface(
    fn=ocr_and_search,
    inputs=[
        gr.Image(type="filepath", label="Upload Image"),
        gr.Textbox(label="Enter Keyword")
    ],
    outputs=[
        gr.Textbox(label="Extracted Text"),
        gr.HTML("Search Result"),
    ],
    title="OCR and Document Search Web Application",
    description="Upload an image to extract text in Hindi and English and search for keywords."
)

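# Example (hypothetical file path): the pipeline can also be called directly without the UI, e.g.
#   extracted, highlighted = ocr_and_search("sample_page.png", "invoice")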
if __name__ == "__main__":
    interface.launch(share=True)