HUANG-Stephanie committed
Commit d005da4
1 Parent(s): 7e4dcaf

Update app.py

Files changed (1):
  app.py +60 -38
app.py CHANGED
@@ -3,6 +3,13 @@ import sys
 
 import gradio as gr
 import torch
+from colpali_engine.models.paligemma_colbert_architecture import ColPali
+from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
+from colpali_engine.utils.colpali_processing_utils import (
+    process_images,
+    process_queries,
+)
+import spaces
 from pdf2image import convert_from_path
 from PIL import Image
 from torch.utils.data import DataLoader
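Note: spaces is the helper package for Hugging Face ZeroGPU Spaces; functions decorated with @spaces.GPU (used below) are allocated a GPU only for the duration of each call. A minimal sketch of the pattern, assuming the app runs on ZeroGPU hardware:

import spaces
import torch

@spaces.GPU
def device_check() -> str:
    # CUDA is only visible inside a @spaces.GPU-decorated call on ZeroGPU
    return torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"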
@@ -11,11 +18,23 @@ from transformers import AutoProcessor
 
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), './colpali-main')))
 
-from colpali_engine.models.paligemma_colbert_architecture import ColPali
-from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
-from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
+# Load model
+model_name = "vidore/colpali"
+token = os.environ.get("HF_TOKEN")
+model = ColPali.from_pretrained(
+    "google/paligemma-3b-mix-448", torch_dtype=torch.bfloat16, device_map="cpu", token = token).eval()
 
+model.load_adapter(model_name)
+processor = AutoProcessor.from_pretrained(model_name, token = token)
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+if device != model.device:
+    model.to(device)
+mock_image = Image.new("RGB", (448, 448), (255, 255, 255))
+
+
+@spaces.GPU
 def search(query: str, ds, images, k):
+
     qs = []
     with torch.no_grad():
         batch_query = process_queries(processor, [query], mock_image)
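The model now loads at import time: PaliGemma-3B base weights in bfloat16 on CPU, with the vidore/colpali adapter applied on top and moved to GPU when one is available; mock_image is a blank 448x448 placeholder image that process_queries expects alongside the text. A quick smoke test of this setup, sketched against the globals defined above (assumes the weights are already cached):

with torch.no_grad():
    batch = process_queries(processor, ["test query"], mock_image)
    batch = {k: v.to(model.device) for k, v in batch.items()}
    embeddings = model(**batch)
print(embeddings.shape)  # one multi-vector embedding per query, e.g. (1, n_tokens, dim)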
@@ -23,7 +42,6 @@ def search(query: str, ds, images, k):
         embeddings_query = model(**batch_query)
         qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
 
-    # run evaluation
    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)
 
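CustomEvaluator(is_multi_vector=True) scores each query against every indexed page with ColBERT-style late interaction: each query-token embedding is matched to its most similar page-patch embedding, and those maxima are summed. An illustrative sketch of that MaxSim score for one query/page pair (not the library's exact code):

import torch

def maxsim(q: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
    # q: (n_query_tokens, dim), d: (n_page_patches, dim)
    sim = q @ d.T                       # similarity of every query token to every patch
    return sim.max(dim=1).values.sum()  # best patch per token, summed over tokens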
@@ -36,19 +54,24 @@ def search(query: str, ds, images, k):
     return results
 
 
-def index(file, ds):
+@spaces.GPU
+def index(files, ds):
     """Example script to run inference with ColPali"""
     images = []
-    for f in file:
+    for f in files:
         images.extend(convert_from_path(f))
 
+    if len(images) >= 150:
+        raise gr.Error("The number of images in the dataset should be less than 150.")
+
     # run inference - docs
     dataloader = DataLoader(
         images,
         batch_size=4,
         shuffle=False,
         collate_fn=lambda x: process_images(processor, x),
-        )
+    )
+
     for batch_doc in tqdm(dataloader):
         with torch.no_grad():
             batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
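convert_from_path rasterizes every page of a PDF into a PIL image (it requires the Poppler system package at runtime), and the new 150-image cap presumably keeps indexing within the Space's resource limits. Minimal usage sketch, with a placeholder file name:

from pdf2image import convert_from_path

pages = convert_from_path("example.pdf", dpi=200)  # "example.pdf" is hypothetical
print(len(pages), pages[0].size)                   # page count and pixel size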
@@ -56,42 +79,41 @@ def index(file, ds):
         ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
     return f"Uploaded and converted {len(images)} pages", ds, images
 
-COLORS = ["#4285f4", "#db4437", "#f4b400", "#0f9d58", "#e48ef1"]
-# Load model
-model_name = "vidore/colpali"
-token = os.environ.get("HF_TOKEN")
-model = ColPali.from_pretrained(
-    "google/paligemma-3b-mix-448", torch_dtype=torch.bfloat16, device_map="cpu", token=token
-).eval()
-model.load_adapter(model_name)
-processor = AutoProcessor.from_pretrained(model_name, token=token)
-device = "cuda:0" if torch.cuda.is_available() else "cpu"
-if device != model.device:
-    model.to(device)
-mock_image = Image.new("RGB", (448, 448), (255, 255, 255))
 
-with gr.Blocks() as demo:
-    gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models 📚🔍")
+def get_example():
+    return [[["climate_youth_magazine.pdf"], "How much tropical forest is cut annually ?"]]
+
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models 📚")
     gr.Markdown("## 1️⃣ Upload PDFs")
-    file = gr.File(file_types=["pdf"], file_count="multiple")
 
-    gr.Markdown("## 2️⃣ Convert the PDFs and upload")
-    convert_button = gr.Button("🔄 Convert and upload")
-    message = gr.Textbox("Files not yet uploaded")
-    embeds = gr.State(value=[])
-    imgs = gr.State(value=[])
+    with gr.Row():
+        with gr.Column(scale=2):
+            gr.Markdown("## 1️⃣ Upload PDFs")
+            file = gr.File(file_types=["pdf"], file_count="multiple", label="Upload PDFs")
 
-    # Define the actions
-    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
+            convert_button = gr.Button("🔄 Index documents")
+            message = gr.Textbox("Files not yet uploaded", label="Status")
+            embeds = gr.State(value=[])
+            imgs = gr.State(value=[])
 
-    gr.Markdown("## 3️⃣ Search")
-    query = gr.Textbox(placeholder="Enter your query here")
-    search_button = gr.Button("🔍 Search")
-    message2 = gr.Textbox("Query not yet set")
-    output_img = gr.Image()
-    k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)
+        with gr.Column(scale=3):
+            gr.Markdown("## 2️⃣ Search")
+            query = gr.Textbox(placeholder="Enter your query here", label="Query")
+            k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)
 
-    search_button.click(search, inputs=[query, embeds, imgs, k], outputs=[message2, output_img])
+    # with gr.Row():
+    #     gr.Examples(
+    #         examples=get_example(),
+    #         inputs=[file, query],
+    #     )
+
+    # Define the actions
+    search_button = gr.Button("🔍 Search", variant="primary")
+    output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)
+
+    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
+    search_button.click(search, inputs=[query, embeds, imgs, k], outputs=[output_gallery])
 
 if __name__ == "__main__":
-    demo.queue(max_size=10).launch(debug=True)
+    demo.queue(max_size=10).launch(debug=True)
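The reworked UI keeps the index in two gr.State holders, embeds (page embeddings) and imgs (page images); state values round-trip through event handlers, which is how index's outputs become search's inputs, and search_button now feeds a single gallery output instead of the old textbox-plus-image pair. The wiring pattern in miniature, as a self-contained sketch with placeholder handlers:

import gradio as gr

def build(files, store):
    # Stand-in for index(): append one fake "vector" per uploaded file
    store = (store or []) + [f"vector for {f}" for f in (files or [])]
    return f"Indexed {len(store)} items", store

def lookup(query, store):
    # Stand-in for search(): naive substring match over the stored items
    return [s for s in store if query.lower() in s.lower()] or ["no match"]

with gr.Blocks() as sketch:
    store = gr.State([])
    files = gr.File(file_count="multiple")
    status = gr.Textbox(label="Status")
    index_btn = gr.Button("Index")
    query = gr.Textbox(label="Query")
    results = gr.JSON()
    index_btn.click(build, inputs=[files, store], outputs=[status, store])
    query.submit(lookup, inputs=[query, store], outputs=[results])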
 