File size: 5,397 Bytes
2b565b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07999ad
 
 
 
 
 
9c1ab38
07999ad
 
 
 
 
 
 
 
 
2b565b6
 
501b6f6
2b565b6
501b6f6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import json
import re
import time  # For generating unique index names

import streamlit as st
import torch
from byaldi import RAGMultiModalModel
from pdf2image import convert_from_bytes, convert_from_path
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"

# Model loading, cached so Streamlit reruns do not reload the weights.
@st.cache_resource
def load_models():
    """Load the ColPali retriever and the Qwen2-VL model/processor pair.

    Returns:
        tuple: (RAG retriever, Qwen2-VL generation model, its processor).
    """
    retriever = RAGMultiModalModel.from_pretrained("vidore/colpali")

    vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    vl_model = vl_model.to(device).eval()  # inference only; no gradients needed

    vl_processor = AutoProcessor.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True
    )

    return retriever, vl_model, vl_processor

RAG, model, processor = load_models()

# Step 1: Upload the file
st.title("OCR extraction")
uploaded_file = st.file_uploader("Upload a PDF or Image", type=["pdf", "png", "jpg", "jpeg"])

# Extracted text must survive Streamlit's script reruns, so it lives in
# session state; seed the key on first run only.
st.session_state.setdefault("extracted_text", None)

if uploaded_file is not None:
    file_type = uploaded_file.name.split('.')[-1].lower()

    # BUG FIX: a newly uploaded file must invalidate the cached extraction,
    # otherwise stale text from the previous document is shown forever.
    if st.session_state.get("last_uploaded_name") != uploaded_file.name:
        st.session_state.extracted_text = None
        st.session_state.last_uploaded_name = uploaded_file.name

    # Step 2: Convert PDF to image (if the input is a PDF)
    if file_type == "pdf":
        st.write("Converting PDF to image...")
        # BUG FIX: convert_from_path() expects a filesystem path, but
        # uploaded_file is an in-memory UploadedFile — read its bytes instead.
        images = convert_from_bytes(uploaded_file.read())
        if not images:
            st.error("The uploaded PDF contains no pages.")
            st.stop()
        image_to_process = images[0]  # only the first page is processed
    else:
        # For images (png/jpg), just open the image directly
        image_to_process = Image.open(uploaded_file)

    # Step 3: Display the uploaded image or PDF
    st.image(image_to_process, caption="Uploaded document", use_column_width=True)

    # Step 4: Unique index name per run so RAG.index never collides
    # (overwrite=False would raise on a duplicate name).
    unique_index_name = f"image_index_{int(time.time())}"

    # Step 5: Perform text extraction only if it's a new file
    if st.session_state.extracted_text is None:
        st.write(f"Indexing document with RAG (index name: {unique_index_name})...")
        image_path = "uploaded_image.png"  # Temporary save path
        image_to_process.save(image_path)

        RAG.index(
            input_path=image_path,
            index_name=unique_index_name,
            store_collection_with_index=False,
            overwrite=False,
        )

        # Step 6: Perform text extraction
        text_query = "Extract all english text and hindi text from the document"
        st.write("Searching the document using RAG...")
        # Retrieval result was never used downstream (generation consumes the
        # image directly); the call is kept for its indexing side effects.
        RAG.search(text_query, k=1)

        # Prepare the chat-style messages combining image and text input.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_to_process},
                    {"type": "text", "text": text_query},
                ],
            }
        ]

        text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)

        inputs = processor(
            text=[text_input],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(device)

        # Generate text output from the image using Qwen2-VL
        st.write("Generating text...")
        # BUG FIX: 100 tokens truncated most documents; 512 gives the
        # "extract all text" prompt room to finish.
        generated_ids = model.generate(**inputs, max_new_tokens=512)
        # Strip the prompt tokens so only newly generated text is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        # Step 7: Store the extracted text in session state
        st.session_state.extracted_text = output_text[0]

    # Step 8: Display the extracted text in JSON format
    st.subheader("Extracted Text (JSON Format):")
    st.json({"extracted_text": st.session_state.extracted_text})

# Step 9: Implement a search functionality on already extracted text
if st.session_state.extracted_text:
    with st.form(key='text_search_form'):
        search_input = st.text_input("Enter a keyword to search within the extracted text:")
        search_action = st.form_submit_button("Search")

    if search_action and search_input:
        # Literal (escaped), case-insensitive match; compiled once and reused
        # for both the membership test and the highlighting substitution.
        pattern = re.compile(re.escape(search_input), re.IGNORECASE)

        matches = []
        # Collect each line containing the keyword, with every occurrence
        # emphasized via markdown (*...*), preserving the original casing.
        for line in st.session_state.extracted_text.split('\n'):
            if pattern.search(line):
                matches.append(pattern.sub(r"*\g<0>*", line))

        st.subheader("Search Results:")
        if matches:
            # BUG FIX: st.markdown was handed a Python list (rendering its
            # repr); join the lines into one markdown string instead.
            st.markdown("\n\n".join(matches))
        else:
            # BUG FIX: typo "Not forund", and the (empty) result list was
            # still rendered after the message due to a missing else-branch.
            st.markdown('Not found')