RufusRubin777 committed on
Commit
d89f6ab
1 Parent(s): c39c19e

Update app.py

Files changed (1)
  1. app.py +28 -33
app.py CHANGED
@@ -5,12 +5,12 @@ from byaldi import RAGMultiModalModel
 from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
 from qwen_vl_utils import process_vision_info
 import torch
+import re
 
 # Load models
 def load_models():
     RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")
-    model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct",
-                                                            trust_remote_code=True, torch_dtype=torch.float32)  # float32 for CPU
+    model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True, torch_dtype=torch.float32)  # float32 for CPU
     processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)
     return RAG, model, processor
 
@@ -18,7 +18,6 @@ RAG, model, processor = load_models()
 
 # Function for OCR and search
 def ocr_and_search(image, keyword):
-
     text_query = "Extract all the text in Sanskrit and English from the image."
 
     # Prepare message for Qwen model
@@ -42,53 +41,49 @@ def ocr_and_search(image, keyword):
         padding=True,
         return_tensors="pt",
     ).to("cpu")  # Use CPU
-
+
     # Generate text
     with torch.no_grad():
         generated_ids = model.generate(**inputs, max_new_tokens=2000)
-    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
-    extracted_text = processor.batch_decode(
-        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-    )[0]
-
-    # Save extracted text to JSON
-    output_json = {"query": text_query, "extracted_text": extracted_text}
-
-    # json_output = json.dumps(output_json, ensure_ascii=False, indent=4)
-
-    gr.Textbox(label= extracted_text)
-
+    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
+    extracted_text = processor.batch_decode(
+        generated_ids_trimmed,
+        skip_special_tokens=True,
+        clean_up_tokenization_spaces=False
+    )[0]
 
-    # Perform keyword search
+    # Perform keyword search with highlighting
     keyword_lower = keyword.lower()
     sentences = extracted_text.split('. ')
-    matched_sentences = [sentence for sentence in sentences if keyword_lower in sentence.lower()]
-
-    gr.Textbox(label= matched_sentences)
+    matched_sentences = []
+    for sentence in sentences:
+        if keyword_lower in sentence.lower():
+            highlighted_sentence = re.sub(
+                f'({re.escape(keyword)})',
+                r'<mark>\1</mark>',
+                sentence,
+                flags=re.IGNORECASE
+            )
+            matched_sentences.append(highlighted_sentence)
 
-    return extracted_text, matched_sentences #, json_output
-
+    return extracted_text, matched_sentences
 
-# Gradio App
+# Gradio App
 def app(image, keyword):
-
     extracted_text, search_results = ocr_and_search(image, keyword)
-
-    search_results_str = "\n".join(search_results) if search_results else "No matches found."
-
-    return extracted_text, search_results_str #, json_output
+    search_results_str = "<br>".join(search_results) if search_results else "No matches found."
+    return extracted_text, search_results_str
 
 # Gradio Interface
 iface = gr.Interface(
-    fn=app,
+    fn=app,
     inputs=[
-        gr.Image(type="pil", label="Upload an Image"),
+        gr.Image(type="pil", label="Upload an Image"),
         gr.Textbox(label="Enter keyword to search in extracted text", placeholder="Keyword")
-    ],
+    ],
     outputs=[
         gr.Textbox(label="Extracted Text"),
-        gr.Textbox(label="Search Results"),
-        # gr.JSON(label="JSON Output")
+        gr.HTML(label="Search Results"),
     ],
     title="OCR and Keyword Search in Images",
 )
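
Two implementation details in the new code may be worth a note. First, the generated_ids_trimmed slice exists because model.generate returns the prompt tokens followed by the newly generated ones, so the prompt must be stripped before decoding. A toy illustration with plain lists standing in for the real tensors (all token values invented):

# Stand-ins for inputs.input_ids and the output of model.generate:
input_ids = [[101, 7592, 2088]]                     # prompt tokens
generated_ids = [[101, 7592, 2088, 4567, 890, 12]]  # prompt + newly generated tokens

# Same slicing as in the diff: drop the echoed prompt, keep only the new tokens.
trimmed = [out[len(inp):] for inp, out in zip(input_ids, generated_ids)]
print(trimmed)  # [[4567, 890, 12]]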
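Second, the new highlighting loop can be exercised on its own. A minimal sketch; the helper name, sample sentence, and keyword are invented for illustration:

import re

def highlight(sentence, keyword):
    # Wrap each case-insensitive occurrence of the keyword in <mark> tags.
    # re.escape guards against keywords containing regex metacharacters,
    # and the backreference \1 preserves the original casing of the match.
    return re.sub(f'({re.escape(keyword)})', r'<mark>\1</mark>', sentence, flags=re.IGNORECASE)

print(highlight("The Vedas are ancient Sanskrit texts.", "sanskrit"))
# The Vedas are ancient <mark>Sanskrit</mark> texts.

Switching the output component from gr.Textbox to gr.HTML is what makes these <mark> tags render as visible highlights rather than as literal text.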