saumitras committed on
Commit
f3d315e
1 Parent(s): b513aa0
Files changed (7) hide show
  1. app.py +130 -0
  2. colpali_manager.py +97 -0
  3. middleware.py +56 -0
  4. milvus_manager.py +162 -0
  5. packages.txt +1 -0
  6. pdf_manager.py +42 -0
  7. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tempfile
3
+ import os
4
+ import fitz # PyMuPDF
5
+ import uuid
6
+
7
+
8
+ from middleware import Middleware
9
+
10
def generate_uuid(state):
    """Return the UUID stored in the session state, creating it on first use.

    Args:
        state: Mutable mapping holding per-session data. The identifier is
            kept under the ``"user_uuid"`` key.

    Returns:
        The session's UUID as a string. A new random UUID is generated and
        written back into ``state`` when the key is missing or ``None``.
    """
    # .get() tolerates a state dict that was created without the key.
    user_uuid = state.get("user_uuid")
    if user_uuid is None:
        user_uuid = str(uuid.uuid4())
        state["user_uuid"] = user_uuid
    return user_uuid
17
+
18
+
19
class PDFSearchApp:
    """Gradio callbacks for uploading a PDF and searching its indexed pages."""

    def __init__(self):
        # Maps a session id to True once that session has indexed a PDF.
        self.indexed_docs = {}
        # Path of the most recently uploaded PDF (across all sessions).
        self.current_pdf = None

    def upload_and_convert(self, state, file, max_pages):
        """Index the uploaded PDF for the current session.

        Args:
            state: Session state mapping passed to ``generate_uuid``.
            file: Uploaded file object exposing a ``name`` path, or ``None``.
            max_pages: Upper bound on the number of pages to index.

        Returns:
            A status message: the number of extracted pages on success, or a
            description of what went wrong.
        """
        user_id = generate_uuid(state)

        if file is None:
            return "No file uploaded"

        print(f"Uploading file: {file.name}, id: {user_id}")

        try:
            self.current_pdf = file.name

            middleware = Middleware(user_id, create_collection=True)
            pages = middleware.index(
                pdf_path=file.name, id=user_id, max_pages=max_pages
            )

            self.indexed_docs[user_id] = True

            return f"Uploaded and extracted {len(pages)} pages"
        except Exception as e:
            # UI boundary: report the failure as status text instead of crashing.
            return f"Error processing PDF: {str(e)}"

    def search_documents(self, state, query, num_results=5):
        """Return the image path of the page that best matches ``query``.

        Args:
            state: Session state mapping passed to ``generate_uuid``.
            query: Free-text search query.
            num_results: Accepted for interface compatibility; currently
                unused, only the single best page is returned.

        Returns:
            The path ``pages/<id>/page_<n>.png`` of the best page, or a
            message string when nothing is indexed, the query is empty, or
            the search fails.
        """
        print(f"Searching for query: {query}")
        user_id = generate_uuid(state)

        # .get() so a session that never uploaded gets the friendly message
        # instead of a KeyError.
        if not self.indexed_docs.get(user_id):
            print("Please index documents first")
            return "Please index documents first"
        if not query:
            print("Please enter a search query")
            return "Please enter a search query"

        try:
            middleware = Middleware(user_id, create_collection=False)

            search_results = middleware.search([query])[0]

            # The second element of the top hit is used as a zero-based page
            # index; page image files are numbered from 1.
            page_num = search_results[0][1] + 1

            print(f"Retrieved page number: {page_num}")

            img_path = f"pages/{user_id}/page_{page_num}.png"

            print(f"Retrieved image path: {img_path}")

            return img_path

        except Exception as e:
            # NOTE(review): message strings are returned to whatever output
            # component is wired to this callback — confirm it can show text.
            return f"Error during search: {str(e)}"
76
+
77
def create_ui():
    """Build the Gradio Blocks interface and wire it to a PDFSearchApp.

    Returns:
        The assembled ``gr.Blocks`` demo, not yet launched.
    """
    search_app = PDFSearchApp()

    with gr.Blocks() as demo:
        # Per-session state; the user UUID is filled in lazily by the callbacks.
        state = gr.State(value={"user_uuid": None})

        gr.Markdown("# Colpali Milvus Search Demo")
        gr.Markdown("This demo showcases how to use [Colpali](https://github.com/illuin-tech/colpali) embeddings with [Milvus](https://milvus.io/) for pdf search.")

        with gr.Tab("Upload PDFs"):
            with gr.Column():
                file_input = gr.File(label="Upload PDFs")
                max_pages_input = gr.Slider(
                    minimum=1, maximum=2000, value=10, step=10, label="Max Pages"
                )
                status = gr.Textbox(label="Status", interactive=False)

        with gr.Tab("Search"):
            with gr.Column():
                query_input = gr.Textbox(label="Query")
                num_results = gr.Slider(
                    minimum=1, maximum=10, value=5, step=1, label="Number of results"
                )
                search_btn = gr.Button("Search")
                results = gr.Image(label="Retrieved Documents")

        # Indexing starts as soon as the selected file changes.
        file_input.change(
            fn=search_app.upload_and_convert,
            inputs=[state, file_input, max_pages_input],
            outputs=[status],
        )

        # The search button shows the best-matching page image.
        search_btn.click(
            fn=search_app.search_documents,
            inputs=[state, query_input, num_results],
            outputs=[results],
        )

    return demo
127
+
128
# Script entry point: build the interface, then start serving it.
if __name__ == "__main__":
    demo = create_ui()
    demo.launch()
colpali_manager.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from colpali_engine.models import ColPali
2
+ from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
3
+ from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
4
+ from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
5
+ from torch.utils.data import DataLoader
6
+ import torch
7
+ from typing import List, cast
8
+
9
+ from tqdm import tqdm
10
+ from PIL import Image
11
+ import os
12
+
13
+ import spaces
14
+
15
# Checkpoint name and torch device shared by everything in this module.
model_name = "vidore/colpali-v1.2"
device = get_torch_device("cuda")

# The model is loaded once, at import time, in bfloat16 and switched to
# evaluation mode.
model = ColPali.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map=device
).eval()

# Processor matching the checkpoint; the cast only narrows the static type.
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
25
+
26
class ColpaliManager:
    """Embed page images and text queries with the module-level ColPali model.

    The model and processor are the module globals ``model`` and
    ``processor``; instances of this class hold no state of their own.
    """

    def __init__(self, device="cuda", model_name="vidore/colpali-v1.2"):
        """Log the requested configuration.

        Both arguments are only printed; the model actually used is the one
        loaded at module import.
        """
        print(f"Initializing ColpaliManager with device {device} and model {model_name}")

    @spaces.GPU
    def get_images(self, paths: list[str]) -> List[Image.Image]:
        """Open every path in ``paths`` with PIL and return the images in order."""
        return list(map(Image.open, paths))

    @spaces.GPU
    def process_images(self, image_paths: list[str], batch_size=5):
        """Embed the images at ``image_paths``.

        Args:
            image_paths: Paths of the images to embed.
            batch_size: Number of images per forward pass.

        Returns:
            A list with one numpy array per image, in input order.
        """
        print(f"Processing {len(image_paths)} image_paths")

        loader = DataLoader(
            dataset=ListDataset[str](self.get_images(image_paths)),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=processor.process_images,
        )

        page_embeddings: List[torch.Tensor] = []
        for batch in tqdm(loader):
            with torch.no_grad():
                on_device = {
                    key: tensor.to(model.device) for key, tensor in batch.items()
                }
                output = model(**on_device)
                # Split the batch back into one tensor per image.
                page_embeddings.extend(torch.unbind(output.to(device)))

        # numpy has no bfloat16, so widen to float32 before converting.
        return [emb.float().cpu().numpy() for emb in page_embeddings]

    @spaces.GPU
    def process_text(self, texts: list[str]):
        """Embed the query strings in ``texts``.

        Args:
            texts: Query strings to embed; each is processed on its own.

        Returns:
            A list with one numpy array per query, in input order.
        """
        print(f"Processing {len(texts)} texts")

        loader = DataLoader(
            dataset=ListDataset[str](texts),
            batch_size=1,
            shuffle=False,
            collate_fn=processor.process_queries,
        )

        query_embeddings: List[torch.Tensor] = []
        for batch in loader:
            with torch.no_grad():
                on_device = {
                    key: tensor.to(model.device) for key, tensor in batch.items()
                }
                output = model(**on_device)

            query_embeddings.extend(torch.unbind(output.to(device)))

        return [emb.float().cpu().numpy() for emb in query_embeddings]
95
+
96
+
97
+
middleware.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from colpali_manager import ColpaliManager
2
+ from milvus_manager import MilvusManager
3
+ from pdf_manager import PdfManager
4
+ import hashlib
5
+
6
+
7
# Module-wide helpers shared by every Middleware instance.
pdf_manager = PdfManager()
colpali_manager = ColpaliManager()
9
+
10
+
11
+
12
class Middleware:
    """Connects PDF rendering, ColPali embedding and Milvus storage for one id."""

    def __init__(self, id: str, create_collection=True):
        """Open the Milvus database file that belongs to ``id``.

        The file name is ``milvus_<h>.db`` where ``<h>`` is the first eight
        hex characters of the MD5 digest of ``id``. The collection is always
        named ``"colpali"``.
        """
        digest = hashlib.md5(id.encode()).hexdigest()
        db_file = f"milvus_{digest[:8]}.db"
        self.milvus_manager = MilvusManager(db_file, "colpali", create_collection)

    def index(self, pdf_path: str, id: str, max_pages: int, pages: list[int] = None):
        """Render, embed and store the pages of a PDF.

        Args:
            pdf_path: Path of the PDF to index.
            id: Identifier passed to the PDF manager when saving images.
            max_pages: Maximum number of pages to process.
            pages: Accepted but not used by this method.

        Returns:
            The list of image paths that were saved and indexed.
        """
        print(f"Indexing {pdf_path}, id: {id}, max_pages: {max_pages}")

        image_paths = pdf_manager.save_images(id, pdf_path, max_pages)
        print(f"Saved {len(image_paths)} images")

        colbert_vecs = colpali_manager.process_images(image_paths)

        # Pair every saved image with the embedding at the same position.
        images_data = [
            {"colbert_vecs": colbert_vecs[pos], "filepath": path}
            for pos, path in enumerate(image_paths)
        ]
        print(f"Inserting {len(images_data)} images data to Milvus")

        self.milvus_manager.insert_images_data(images_data)
        print("Indexing completed")

        return image_paths

    def search(self, search_queries: list[str]):
        """Run each query against Milvus and collect the top-1 results.

        Returns:
            One entry per query, each being whatever
            ``MilvusManager.search`` returned for ``topk=1``.
        """
        print(f"Searching for {len(search_queries)} queries")

        collected = []
        for query in search_queries:
            print(f"Searching for query: {query}")
            embedded = colpali_manager.process_text([query])
            hits = self.milvus_manager.search(embedded[0], topk=1)
            print(f"Search result: {hits} for query: {query}")
            collected.append(hits)

        return collected
56
+
milvus_manager.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pymilvus import MilvusClient, DataType
2
+ import numpy as np
3
+ import concurrent.futures
4
+
5
+
6
class MilvusManager:
    """Store and search multi-vector document embeddings in Milvus.

    Each embedding vector of a document is stored as its own row, tagged with
    the document's ``doc_id`` and the vector's position ``seq_id``. Searching
    first retrieves candidate documents by single-vector similarity and then
    re-scores every candidate with all of its vectors.
    """

    def __init__(self, milvus_uri, collection_name, create_collection, dim=128):
        """Connect to Milvus and optionally (re)create the collection.

        Args:
            milvus_uri: URI handed to ``MilvusClient``.
            collection_name: Name of the collection to use.
            create_collection: When true, drop any existing collection of
                that name, create a fresh one and build its vector index.
            dim: Dimensionality of the stored vectors.
        """
        self.client = MilvusClient(uri=milvus_uri)
        self.collection_name = collection_name
        self.dim = dim
        if self.client.has_collection(collection_name=self.collection_name):
            self.client.load_collection(collection_name)

        if create_collection:
            self.create_collection()
            self.create_index()

    def create_collection(self):
        """Drop the collection if it exists and create it with a fresh schema."""
        if self.client.has_collection(collection_name=self.collection_name):
            self.client.drop_collection(collection_name=self.collection_name)
        # The keyword is singular; the plural spelling is silently ignored.
        schema = self.client.create_schema(
            auto_id=True,
            enable_dynamic_field=True,
        )
        schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)
        schema.add_field(
            field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim
        )
        schema.add_field(field_name="seq_id", datatype=DataType.INT16)
        schema.add_field(field_name="doc_id", datatype=DataType.INT64)
        schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)

        self.client.create_collection(
            collection_name=self.collection_name, schema=schema
        )

    def create_index(self):
        """Release the collection and (re)build the HNSW index on ``vector``."""
        self.client.release_collection(collection_name=self.collection_name)
        # Drop the index under the same name it is created with below.
        self.client.drop_index(
            collection_name=self.collection_name, index_name="vector_index"
        )
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="vector",
            index_name="vector_index",
            index_type="HNSW",
            metric_type="IP",
            params={
                "M": 16,
                "efConstruction": 500,
            },
        )

        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )

    def create_scalar_index(self):
        """Release the collection and build an inverted index on ``doc_id``."""
        self.client.release_collection(collection_name=self.collection_name)

        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="doc_id",
            index_name="int32_index",
            index_type="INVERTED",
        )

        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )

    def search(self, data, topk):
        """Return the ``topk`` best documents for a multi-vector query.

        Args:
            data: Query vectors, one row per query token.
            topk: Maximum number of documents to return.

        Returns:
            A list of ``(score, doc_id)`` tuples sorted by descending score.
            The score of a document is, summed over query vectors, the
            largest inner product with any of the document's vectors.
        """
        search_params = {"metric_type": "IP", "params": {}}
        # Stage 1: collect candidate documents. Only doc_id is read from the
        # hits, so no other field is requested.
        results = self.client.search(
            self.collection_name,
            data,
            limit=50,
            output_fields=["doc_id"],
            search_params=search_params,
        )
        doc_ids = {hit["entity"]["doc_id"] for hits in results for hit in hits}

        def rerank_single_doc(doc_id):
            # Fetch the vectors of exactly this document; including any other
            # doc_id would leak a neighbour's vectors into the score.
            # NOTE(review): a document with more than 1000 stored vectors is
            # truncated by this limit — confirm against the embedding length.
            rows = self.client.query(
                collection_name=self.collection_name,
                filter=f"doc_id == {doc_id}",
                output_fields=["seq_id", "vector", "doc"],
                limit=1000,
            )
            doc_vecs = np.vstack([row["vector"] for row in rows])
            score = np.dot(data, doc_vecs.T).max(1).sum()
            return (score, doc_id)

        # Stage 2: re-score the candidates; the queries overlap in threads.
        with concurrent.futures.ThreadPoolExecutor(max_workers=300) as executor:
            scores = list(executor.map(rerank_single_doc, doc_ids))

        scores.sort(key=lambda item: item[0], reverse=True)
        # Slicing already copes with fewer than topk results.
        return scores[:topk]

    def insert(self, data):
        """Insert all vectors of one document.

        Args:
            data: Mapping with ``"colbert_vecs"`` (the document's vectors),
                ``"doc_id"`` and ``"filepath"``. The file path is stored only
                on the first row (``seq_id`` 0); other rows get ``""``.
        """
        colbert_vecs = list(data["colbert_vecs"])
        if not colbert_vecs:
            # Nothing to store for a document without vectors.
            return

        rows = [
            {
                "vector": vec,
                "seq_id": seq_id,
                "doc_id": data["doc_id"],
                "doc": data["filepath"] if seq_id == 0 else "",
            }
            for seq_id, vec in enumerate(colbert_vecs)
        ]
        self.client.insert(self.collection_name, rows)

    def get_images_as_doc(self, images_with_vectors: list):
        """Attach a ``doc_id`` (the list position) to every image record.

        Args:
            images_with_vectors: Records with ``"colbert_vecs"`` and
                ``"filepath"`` keys.

        Returns:
            New records with ``"colbert_vecs"``, ``"doc_id"`` and
            ``"filepath"`` keys, in input order.
        """
        return [
            {
                "colbert_vecs": record["colbert_vecs"],
                "doc_id": position,
                "filepath": record["filepath"],
            }
            for position, record in enumerate(images_with_vectors)
        ]

    def insert_images_data(self, image_data):
        """Assign document ids to ``image_data`` and insert every record."""
        for record in self.get_images_as_doc(image_data):
            self.insert(record)
161
+
162
+
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ poppler-utils
pdf_manager.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pdf2image import convert_from_path
2
+ import os
3
+ import shutil
4
+
5
class PdfManager:
    """Render PDF pages to PNG files under ``pages/<id>/``."""

    def __init__(self):
        pass

    def clear_and_recreate_dir(self, output_folder):
        """Delete ``output_folder`` if it exists, then create it empty."""
        print(f"Clearing output folder {output_folder}")

        if os.path.exists(output_folder):
            shutil.rmtree(output_folder)

        os.makedirs(output_folder)

    def save_images(self, id, pdf_path, max_pages, pages: list[int] = None) -> list[str]:
        """Render pages of a PDF and save them as PNG files.

        Args:
            id: Identifier used as the sub-folder name under ``pages/``.
            pdf_path: Path of the PDF to render.
            max_pages: Maximum number of pages to save; a falsy value means
                no limit.
            pages: Optional zero-based page indices to keep; when given, all
                other pages are skipped.

        Returns:
            The paths of the files that were actually written, in page order.
            File names use one-based page numbers (``page_1.png``, ...).
        """
        import math

        output_folder = f"pages/{id}/"

        render_kwargs = {}
        if not pages and max_pages and max_pages > 0:
            # Without a page filter only the first max_pages pages can ever be
            # saved, so there is no point in rendering the rest of the PDF.
            render_kwargs["last_page"] = math.ceil(max_pages)
        images = convert_from_path(pdf_path, **render_kwargs)

        print(f"Saving images from {pdf_path} to {output_folder}. Max pages: {max_pages}")

        self.clear_and_recreate_dir(output_folder)

        saved_paths = []
        for i, image in enumerate(images):
            if max_pages and len(saved_paths) >= max_pages:
                break

            if pages and i not in pages:
                continue

            # Path format (including the doubled slash) is kept as-is so the
            # returned strings stay identical to earlier runs.
            full_save_path = f"{output_folder}/page_{i + 1}.png"
            image.save(full_save_path, "PNG")
            # Record the real path; with a page filter the saved numbers are
            # not simply 1..n.
            saved_paths.append(full_save_path)

        return saved_paths
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio==4.25.0
2
+ PyMuPDF==1.24.9
3
+ pdf2image==1.17.0
4
+ pymilvus==2.4.9
5
+ colpali_engine==0.3.4
6
+ tqdm==4.66.5
7
+ pillow==10.4.0
8
+ spaces==0.30.4