HUANG-Stephanie commited on
Commit
33de980
1 Parent(s): b6cdf0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -75
app.py CHANGED
@@ -1,22 +1,25 @@
 
1
  import os
2
  import sys
3
 
 
 
 
 
 
 
 
 
4
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), './colpali-main')))
5
 
6
- import gradio as gr
7
- import torch
8
  from colpali_engine.models.paligemma_colbert_architecture import ColPali
9
  from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
10
  from colpali_engine.utils.colpali_processing_utils import (
11
  process_images,
12
  process_queries,
13
  )
14
- import spaces
15
- from pdf2image import convert_from_path
16
- from PIL import Image
17
- from torch.utils.data import DataLoader
18
- from tqdm import tqdm
19
- from transformers import AutoProcessor
20
 
21
  # Load model
22
  model_name = "vidore/colpali"
@@ -31,88 +34,52 @@ if device != model.device:
31
  model.to(device)
32
  mock_image = Image.new("RGB", (448, 448), (255, 255, 255))
33
 
 
 
 
34
 
35
- @spaces.GPU
36
- def search(query: str, ds, images, k):
37
-
38
- qs = []
39
- with torch.no_grad():
40
- batch_query = process_queries(processor, [query], mock_image)
41
- batch_query = {k: v.to(device) for k, v in batch_query.items()}
42
- embeddings_query = model(**batch_query)
43
- qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
44
-
45
- retriever_evaluator = CustomEvaluator(is_multi_vector=True)
46
- scores = retriever_evaluator.evaluate(qs, ds)
47
-
48
- top_k_indices = scores.argsort(axis=1)[0][-k:][::-1]
49
-
50
- results = []
51
- for idx in top_k_indices:
52
- results.append((images[idx], f"Page {idx}"))
53
-
54
- return results
55
-
56
-
57
- @spaces.GPU
58
- def index(files, ds):
59
- """Example script to run inference with ColPali"""
60
  images = []
61
- for f in files:
62
- images.extend(convert_from_path(f))
63
-
64
- if len(images) >= 150:
65
- raise gr.Error("The number of images in the dataset should be less than 150.")
66
-
67
- # run inference - docs
68
  dataloader = DataLoader(
69
  images,
70
  batch_size=4,
71
  shuffle=False,
72
  collate_fn=lambda x: process_images(processor, x),
73
- )
74
-
75
- for batch_doc in tqdm(dataloader):
76
  with torch.no_grad():
77
  batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
78
  embeddings_doc = model(**batch_doc)
79
  ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
80
- return f"Uploaded and converted {len(images)} pages", ds, images
81
-
82
-
83
- def get_example():
84
- return [[["climate_youth_magazine.pdf"], "How much tropical forest is cut annually ?"]]
85
-
86
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
87
- gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models 📚")
88
 
89
- with gr.Row():
90
- with gr.Column(scale=2):
91
- gr.Markdown("## 1️⃣ Upload PDFs")
92
- file = gr.File(file_types=["pdf"], file_count="multiple", label="Upload PDFs")
93
-
94
- convert_button = gr.Button("🔄 Index documents")
95
- message = gr.Textbox("Files not yet uploaded", label="Status")
96
- embeds = gr.State(value=[])
97
- imgs = gr.State(value=[])
98
 
99
- with gr.Column(scale=3):
100
- gr.Markdown("## 2️⃣ Search")
101
- query = gr.Textbox(placeholder="Enter your query here", label="Query")
102
- k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)
103
 
104
- # with gr.Row():
105
- # gr.Examples(
106
- # examples=get_example(),
107
- # inputs=[file, query],
108
- # )
109
 
110
- # Define the actions
111
- search_button = gr.Button("🔍 Search", variant="primary")
112
- output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)
113
 
114
- convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
115
- search_button.click(search, inputs=[query, embeds, imgs, k], outputs=[output_gallery])
116
 
117
  if __name__ == "__main__":
118
- demo.queue(max_size=10).launch(debug=True)
 
 
1
+ import io
2
  import os
3
  import sys
4
 
5
+ from fastapi import FastAPI, File, UploadFile
6
+ from typing import List
7
+ import torch
8
+ from pdf2image import convert_from_path
9
+ from PIL import Image
10
+ from torch.utils.data import DataLoader
11
+ from transformers import AutoProcessor
12
+
13
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), './colpali-main')))
14
 
 
 
15
  from colpali_engine.models.paligemma_colbert_architecture import ColPali
16
  from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
17
  from colpali_engine.utils.colpali_processing_utils import (
18
  process_images,
19
  process_queries,
20
  )
21
+
22
+ app = FastAPI()
 
 
 
 
23
 
24
  # Load model
25
  model_name = "vidore/colpali"
 
34
  model.to(device)
35
  mock_image = Image.new("RGB", (448, 448), (255, 255, 255))
36
 
37
+ # In-memory storage
38
+ ds = []
39
+ images = []
40
 
41
+ @app.post("/index")
42
+ async def index(files: List[UploadFile] = File(...)):
43
+ global ds, images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  images = []
45
+ ds = []
46
+ for file in files:
47
+ content = await file.read()
48
+ pdf_image_list = convert_from_path(io.BytesIO(content))
49
+ images.extend(pdf_image_list)
50
+
 
51
  dataloader = DataLoader(
52
  images,
53
  batch_size=4,
54
  shuffle=False,
55
  collate_fn=lambda x: process_images(processor, x),
56
+ )
57
+ for batch_doc in dataloader:
 
58
  with torch.no_grad():
59
  batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
60
  embeddings_doc = model(**batch_doc)
61
  ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
62
+
63
+ return {"message": f"Uploaded and converted {len(images)} pages"}
 
 
 
 
 
 
64
 
65
+ @app.post("/search")
66
+ async def search(query: str, k: int):
67
+ qs = []
68
+ with torch.no_grad():
69
+ batch_query = process_queries(processor, [query], mock_image)
70
+ batch_query = {k: v.to(device) for k, v in batch_query.items()}
71
+ embeddings_query = model(**batch_query)
72
+ qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
 
73
 
74
+ retriever_evaluator = CustomEvaluator(is_multi_vector=True)
75
+ scores = retriever_evaluator.evaluate(qs, ds)
 
 
76
 
77
+ top_k_indices = scores.argsort(axis=1)[0][-k:][::-1]
 
 
 
 
78
 
79
+ results = [{"page": idx, "image": "image_placeholder"} for idx in top_k_indices]
 
 
80
 
81
+ return {"results": results}
 
82
 
83
  if __name__ == "__main__":
84
+ import uvicorn
85
+ uvicorn.run(app, host="0.0.0.0", port=8082, reload=True)