HUANG-Stephanie committed on
Commit
3a0c450
1 Parent(s): 5d3c3b6

Update colpali-main/demo/app.py

Browse files
Files changed (1) hide show
  1. colpali-main/demo/app.py +6 -17
colpali-main/demo/app.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- import sys
3
 
4
  import gradio as gr
5
  import torch
@@ -9,13 +8,12 @@ from torch.utils.data import DataLoader
9
  from tqdm import tqdm
10
  from transformers import AutoProcessor
11
 
12
- sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
13
-
14
  from colpali_engine.models.paligemma_colbert_architecture import ColPali
15
  from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
16
  from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
17
 
18
- def search(query: str, ds, images, k):
 
19
  qs = []
20
  with torch.no_grad():
21
  batch_query = process_queries(processor, [query], mock_image)
@@ -26,17 +24,8 @@ def search(query: str, ds, images, k):
26
  # run evaluation
27
  retriever_evaluator = CustomEvaluator(is_multi_vector=True)
28
  scores = retriever_evaluator.evaluate(qs, ds)
29
-
30
- top_k_indices = scores.argsort(axis=1)[0][-k:][::-1]
31
-
32
- results = []
33
- for idx in top_k_indices:
34
- results.append((images[idx], f"Page {idx}"))
35
-
36
- return results
37
-
38
- #best_page = int(scores.argmax(axis=1).item())
39
- #return f"The most relevant page is {best_page}", images[best_page]
40
 
41
 
42
  def index(file, ds):
@@ -59,6 +48,7 @@ def index(file, ds):
59
  ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
60
  return f"Uploaded and converted {len(images)} pages", ds, images
61
 
 
62
  COLORS = ["#4285f4", "#db4437", "#f4b400", "#0f9d58", "#e48ef1"]
63
  # Load model
64
  model_name = "vidore/colpali"
@@ -90,9 +80,8 @@ with gr.Blocks() as demo:
90
  search_button = gr.Button("🔍 Search")
91
  message2 = gr.Textbox("Query not yet set")
92
  output_img = gr.Image()
93
- k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)
94
 
95
- search_button.click(search, inputs=[query, embeds, imgs, k], outputs=[message2, output_img])
96
 
97
 
98
  if __name__ == "__main__":
 
1
  import os
 
2
 
3
  import gradio as gr
4
  import torch
 
8
  from tqdm import tqdm
9
  from transformers import AutoProcessor
10
 
 
 
11
  from colpali_engine.models.paligemma_colbert_architecture import ColPali
12
  from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
13
  from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
14
 
15
+
16
+ def search(query: str, ds, images):
17
  qs = []
18
  with torch.no_grad():
19
  batch_query = process_queries(processor, [query], mock_image)
 
24
  # run evaluation
25
  retriever_evaluator = CustomEvaluator(is_multi_vector=True)
26
  scores = retriever_evaluator.evaluate(qs, ds)
27
+ best_page = int(scores.argmax(axis=1).item())
28
+ return f"The most relevant page is {best_page}", images[best_page]
 
 
 
 
 
 
 
 
 
29
 
30
 
31
  def index(file, ds):
 
48
  ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
49
  return f"Uploaded and converted {len(images)} pages", ds, images
50
 
51
+
52
  COLORS = ["#4285f4", "#db4437", "#f4b400", "#0f9d58", "#e48ef1"]
53
  # Load model
54
  model_name = "vidore/colpali"
 
80
  search_button = gr.Button("🔍 Search")
81
  message2 = gr.Textbox("Query not yet set")
82
  output_img = gr.Image()
 
83
 
84
+ search_button.click(search, inputs=[query, embeds, imgs], outputs=[message2, output_img])
85
 
86
 
87
  if __name__ == "__main__":