rk404 committed on
Commit
2b565b6
1 Parent(s): 7e2652e

Create app.py

Files changed (1)
  1. app.py +142 -0
app.py ADDED
import streamlit as st
from PIL import Image
from pdf2image import convert_from_bytes
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
import time  # For generating unique index names
import json
import re

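# Environment note (assumed setup, not enforced here): pdf2image needs the poppler utilities installed
# on the system (for example via `apt-get install poppler-utils`, or a packages.txt entry on Hugging Face
# Spaces); byaldi provides the ColPali-based RAGMultiModalModel and qwen_vl_utils provides process_vision_info.
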
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize Qwen2-VL model and processor
@st.cache_resource
def load_models():
    # Load RAG MultiModalModel and Qwen2-VL model
    RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")

    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16
    ).to(device).eval()

    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)

    return RAG, model, processor

RAG, model, processor = load_models()

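# Resource note: the 7B Qwen2-VL checkpoint in bfloat16 typically needs on the order of 16 GB of GPU
# memory for the weights alone; the app still runs on CPU, only much more slowly. If memory is tight,
# the smaller Qwen/Qwen2-VL-2B-Instruct checkpoint is a plausible drop-in substitute.
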
# Step 1: Upload the file
st.title("OCR extraction")
uploaded_file = st.file_uploader("Upload a PDF or Image", type=["pdf", "png", "jpg", "jpeg"])

# Initialize a session state entry to store extracted text so it persists across reruns
if "extracted_text" not in st.session_state:
    st.session_state.extracted_text = None

if uploaded_file is not None:
    file_type = uploaded_file.name.split('.')[-1].lower()

    # Step 2: Convert PDF to image (if the input is a PDF)
    if file_type == "pdf":
        st.write("Converting PDF to image...")
        # The uploader yields an in-memory file, so convert from its bytes rather than from a path
        images = convert_from_bytes(uploaded_file.read())
        image_to_process = images[0]  # Only the first page is processed
    else:
        # For images (png/jpg), just open the image directly
        image_to_process = Image.open(uploaded_file)

    # Step 3: Display the uploaded image or PDF
    st.image(image_to_process, caption="Uploaded document", use_column_width=True)

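    # Compatibility note: newer Streamlit releases prefer use_container_width over use_column_width
    # for st.image; the older keyword still works but may emit a deprecation warning.
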
    # Step 4: Dynamically create a unique index name using the current timestamp
    unique_index_name = f"image_index_{int(time.time())}"

    # Step 5: Perform text extraction only if it's a new file
    if st.session_state.extracted_text is None:
        st.write(f"Indexing document with RAG (index name: {unique_index_name})...")
        image_path = "uploaded_image.png"  # Temporary save path
        image_to_process.save(image_path)

        RAG.index(
            input_path=image_path,
            index_name=unique_index_name,  # Use the unique index name
            store_collection_with_index=False,
            overwrite=False
        )

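        # Persistence note: byaldi writes each index to disk, so every upload leaves a new index behind;
        # on a long-running deployment some periodic cleanup of old indexes is probably worthwhile.
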
        # Step 6: Perform text extraction
        text_query = "Extract all English text and Hindi text from the document"
        st.write("Searching the document using RAG...")
        results = RAG.search(text_query, k=1)  # Retrieval hit; the page image itself is passed to the model below

        # Prepare the messages for text and image input
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_to_process},
                    {"type": "text", "text": text_query},
                ],
            }
        ]

        # Prepare and process image and text inputs
        text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)

        inputs = processor(
            text=[text_input],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        inputs = inputs.to(device)

        # Generate text output from the image using Qwen2-VL
        st.write("Generating text...")
        generated_ids = model.generate(**inputs, max_new_tokens=100)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

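        # Generation note: max_new_tokens=100 keeps latency low but can truncate the transcription of
        # text-dense pages; raising this limit is the obvious knob if the extracted text looks cut off.
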
        # Step 7: Store the extracted text in session state
        st.session_state.extracted_text = output_text[0]

    # Step 8: Display the extracted text in JSON format
    extracted_text = st.session_state.extracted_text
    structured_text = {"extracted_text": extracted_text}

    st.subheader("Extracted Text (JSON Format):")
    st.json(structured_text)

    # Step 9: Implement search functionality on the already extracted text
    if st.session_state.extracted_text:
        with st.form(key='search_form'):
            search_query = st.text_input("Enter keyword to search within the extracted text:")
            search_button = st.form_submit_button("Search")

        if search_button and search_query:
            # Perform a case-insensitive search and wrap every match in bold markers
            extracted_text = st.session_state.extracted_text  # Use the already extracted text
            highlighted_text, num_matches = re.subn(
                re.escape(search_query),
                lambda m: f"**{m.group(0)}**",
                extracted_text,
                flags=re.IGNORECASE,
            )

            st.subheader("Search Results:")
            if num_matches == 0:
                st.markdown("No matches found.")
            else:
                st.markdown(highlighted_text)
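
    # Optional extra (a minimal sketch, not required by the flow above): offer the extracted text as a
    # JSON download via Streamlit's st.download_button; this also puts the json import above to use.
    st.download_button(
        label="Download extracted text (JSON)",
        data=json.dumps(structured_text, ensure_ascii=False, indent=2),
        file_name="extracted_text.json",
        mime="application/json",
    )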