Pranathi1 commited on
Commit
987f8b9
1 Parent(s): b98c13a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -0
app.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from byaldi import RAGMultiModalModel
3
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
4
+ from qwen_vl_utils import process_vision_info
5
+ import torch
6
+ from PIL import Image
7
+ import re
8
+
9
+ def highlight_text(text, term):
10
+ highlighted_text = re.sub(f"({term})", r'<mark>\1</mark>', text, flags=re.IGNORECASE)
11
+ return highlighted_text
12
+
13
+ @st.cache_resource
14
+ def load_models():
15
+ RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")
16
+
17
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
18
+ "Qwen/Qwen2-VL-2B-Instruct",
19
+ trust_remote_code=True,
20
+ torch_dtype=torch.bfloat16).cuda().eval()
21
+
22
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)
23
+
24
+ return model, processor, RAG
25
+
26
+ if 'is_indexed' not in st.session_state:
27
+ st.session_state['is_indexed'] = False
28
+
29
+ st.title("Image to Text Extraction and Search with Highlighting")
30
+
31
+ uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
32
+ if uploaded_file is not None:
33
+ # Save the uploaded image to a temporary file
34
+ temp_file_path = f"temp_{uploaded_file.name}"
35
+ with open(temp_file_path, "wb") as f:
36
+ f.write(uploaded_file.getbuffer())
37
+
38
+ image = Image.open(uploaded_file)
39
+ images = [image]
40
+ st.image(image, caption='Uploaded Image', use_column_width=True)
41
+
42
+ model, processor, RAG = load_models()
43
+
44
+ # Text Extraction from Image
45
+ messages = [
46
+ {
47
+ "role": "user",
48
+ "content": [
49
+ {
50
+ "type": "image",
51
+ "image": image,
52
+ },
53
+ {"type": "text", "text": "Extract the text from this image."},
54
+ ],
55
+ }
56
+ ]
57
+
58
+ # Process the image and text for input
59
+ text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
60
+ image_inputs, video_inputs = process_vision_info(messages)
61
+ inputs = processor(
62
+ text=[text],
63
+ images=image_inputs,
64
+ videos=video_inputs,
65
+ padding=True,
66
+ return_tensors="pt",
67
+ )
68
+ inputs = inputs.to("cuda")
69
+
70
+ # Generate the text from the image using the model
71
+ generated_ids = model.generate(**inputs, max_new_tokens=5000)
72
+
73
+ generated_ids_trimmed = [
74
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
75
+ ]
76
+ extracted_text = processor.batch_decode(
77
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
78
+ )
79
+ extracted_text = "\n".join(extracted_text) # Convert list to a single string
80
+
81
+ st.subheader("Extracted Text:")
82
+ st.write(extracted_text)
83
+
84
+ # Save the extracted text to a file
85
+ with open("extracted_text.txt", "w", encoding="utf-8") as f:
86
+ f.write(extracted_text)
87
+
88
+ # Search Query
89
+ query = st.text_input("Search in Extracted Text", "")
90
+
91
+ if query:
92
+ # If the query is a single word, highlight its occurrences
93
+ if len(query.split()) == 1:
94
+ # Highlight the search term in the extracted text
95
+ highlighted_text = highlight_text(extracted_text, query)
96
+ st.subheader("Search Result (Word Occurrences):")
97
+ st.markdown(highlighted_text, unsafe_allow_html=True)
98
+
99
+ # If the query is more than one word, use RAG for Intelli search
100
+ else:
101
+ # Only index the image once
102
+ if not st.session_state['is_indexed']:
103
+ try:
104
+ RAG.index(
105
+ input_path=temp_file_path, # Use the local file path for indexing
106
+ index_name="image_index", # index will be saved at index_root/index_name/
107
+ store_collection_with_index=False,
108
+ overwrite=True
109
+ )
110
+ st.session_state['is_indexed'] = True # Mark document as indexed
111
+ except Exception as e:
112
+ st.error(f"Error during indexing: {str(e)}")
113
+
114
+ # Perform search using the query
115
+ try:
116
+ results = RAG.search(query, k=1)
117
+ query_image_index = results[0]["page_num"] - 1
118
+
119
+ # Get the result text related to the query
120
+ query_messages = [
121
+ {
122
+ "role": "user",
123
+ "content": [
124
+ {
125
+ "type": "image",
126
+ "image": images[query_image_index],
127
+ },
128
+ {"type": "text", "text": query},
129
+ ],
130
+ }
131
+ ]
132
+
133
+ # Generate the answer using the RAG model
134
+ text = processor.apply_chat_template(
135
+ query_messages, tokenize=False, add_generation_prompt=True
136
+ )
137
+ image_inputs, video_inputs = process_vision_info(messages)
138
+ inputs = processor(
139
+ text=[text],
140
+ images=image_inputs,
141
+ videos=video_inputs,
142
+ padding=True,
143
+ return_tensors="pt",
144
+ )
145
+ inputs = inputs.to("cuda")
146
+
147
+ generated_ids_query = model.generate(**inputs, max_new_tokens=1000)
148
+ generated_ids_trimmed = [
149
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids_query)
150
+ ]
151
+ query_result = processor.batch_decode(
152
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
153
+ )
154
+
155
+ # Highlight the query within the result
156
+ highlighted_result = highlight_text("\n".join(query_result), query)
157
+
158
+ # Display the query result
159
+ st.subheader("Search Result (Intelli Answer):")
160
+ st.markdown(highlighted_result, unsafe_allow_html=True)
161
+
162
+ except Exception as e:
163
+ st.error(f"Error during search: {str(e)}")
164
+