Spaces:
Running
on
Zero
Running
on
Zero
init
Browse files- app.py +130 -0
- colpali_manager.py +97 -0
- middleware.py +56 -0
- milvus_manager.py +162 -0
- packages.txt +1 -0
- pdf_manager.py +42 -0
- requirements.txt +8 -0
app.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import tempfile
|
3 |
+
import os
|
4 |
+
import fitz # PyMuPDF
|
5 |
+
import uuid
|
6 |
+
|
7 |
+
|
8 |
+
from middleware import Middleware
|
9 |
+
|
10 |
+
def generate_uuid(state):
    """Return the session's UUID, creating and caching one on first use.

    Args:
        state: Mutable per-session mapping. The identifier is stored under
            the ``"user_uuid"`` key so later calls return the same value.

    Returns:
        The UUID string associated with ``state``.
    """
    # ``.get`` treats a missing key like an unset (None) value, so a state
    # dict that was never seeded with "user_uuid" does not raise KeyError.
    if state.get("user_uuid") is None:
        state["user_uuid"] = str(uuid.uuid4())

    return state["user_uuid"]
|
17 |
+
|
18 |
+
|
19 |
+
class PDFSearchApp:
    """Gradio callbacks for indexing an uploaded PDF and searching its pages.

    Bookkeeping is keyed by the per-session UUID returned by
    ``generate_uuid`` so each session tracks its own upload.
    """

    def __init__(self):
        # Maps session UUID -> True once that session has indexed a PDF.
        self.indexed_docs = {}
        # Path of the most recently uploaded PDF.
        self.current_pdf = None

    def upload_and_convert(self, state, file, max_pages):
        """Index an uploaded PDF for the current session.

        Args:
            state: Per-session dict holding the ``"user_uuid"`` entry.
            file: Uploaded file object exposing ``.name`` (its path), or
                ``None`` when nothing is selected.
            max_pages: Upper bound on the number of pages to index.

        Returns:
            A status message describing success or the failure reason.
        """
        # Guard first: nothing to do (and no session id needed) without a file.
        if file is None:
            return "No file uploaded"

        user_id = generate_uuid(state)
        print(f"Uploading file: {file.name}, id: {user_id}")

        try:
            self.current_pdf = file.name

            # create_collection=True makes the manager rebuild the collection,
            # so a new upload replaces the session's previous index.
            middleware = Middleware(user_id, create_collection=True)
            pages = middleware.index(
                pdf_path=file.name, id=user_id, max_pages=max_pages
            )

            self.indexed_docs[user_id] = True
            return f"Uploaded and extracted {len(pages)} pages"
        except Exception as e:
            # UI boundary: report the failure as a status message.
            return f"Error processing PDF: {e}"

    def search_documents(self, state, query, num_results=5):
        """Return the image path of the best-matching page for ``query``.

        Args:
            state: Per-session dict holding the ``"user_uuid"`` entry.
            query: Free-text search query.
            num_results: Currently unused; ``Middleware.search`` always asks
                for the single best hit.

        Returns:
            The path ``pages/<session id>/page_<n>.png`` of the top hit, or a
            message string when nothing can be searched or an error occurs.
        """
        # NOTE(review): create_ui wires this return value to a gr.Image
        # output, so the message strings below may not render as text there.
        print(f"Searching for query: {query}")
        user_id = generate_uuid(state)

        # A session that never uploaded has no entry at all; ``.get`` reports
        # that as "not indexed" instead of raising KeyError.
        if not self.indexed_docs.get(user_id):
            print("Please index documents first")
            return "Please index documents first"
        if not query:
            print("Please enter a search query")
            return "Please enter a search query"

        try:
            middleware = Middleware(user_id, create_collection=False)

            # One result list per query; we sent exactly one query.
            search_results = middleware.search([query])[0]

            # Each hit is (score, doc_id); doc_id is the zero-based page
            # index, while the saved image files are numbered from 1.
            page_num = search_results[0][1] + 1
            print(f"Retrieved page number: {page_num}")

            img_path = f"pages/{user_id}/page_{page_num}.png"
            print(f"Retrieved image path: {img_path}")

            return img_path
        except Exception as e:
            return f"Error during search: {e}"
|
76 |
+
|
77 |
+
def create_ui():
    """Assemble the Gradio Blocks interface and wire up its callbacks.

    Returns:
        The ``gr.Blocks`` instance, ready for ``launch()``.
    """
    search_app = PDFSearchApp()

    with gr.Blocks() as ui:
        # Per-session storage; the callbacks fill in the UUID on first use.
        session_state = gr.State(value={"user_uuid": None})

        gr.Markdown("# Colpali Milvus Search Demo")
        gr.Markdown("This demo showcases how to use [Colpali](https://github.com/illuin-tech/colpali) embeddings with [Milvus](https://milvus.io/) for pdf search.")

        with gr.Tab("Upload PDFs"), gr.Column():
            pdf_upload = gr.File(label="Upload PDFs")
            page_limit = gr.Slider(
                minimum=1, maximum=2000, value=10, step=10, label="Max Pages"
            )
            upload_status = gr.Textbox(label="Status", interactive=False)

        with gr.Tab("Search"), gr.Column():
            query_box = gr.Textbox(label="Query")
            result_count = gr.Slider(
                minimum=1, maximum=10, value=5, step=1, label="Number of results"
            )
            search_button = gr.Button("Search")
            result_image = gr.Image(label="Retrieved Documents")

        # Indexing starts as soon as the selected file changes.
        pdf_upload.change(
            fn=search_app.upload_and_convert,
            inputs=[session_state, pdf_upload, page_limit],
            outputs=[upload_status],
        )

        # Searching is triggered explicitly by the button.
        search_button.click(
            fn=search_app.search_documents,
            inputs=[session_state, query_box, result_count],
            outputs=[result_image],
        )

    return ui
|
127 |
+
|
128 |
+
if __name__ == "__main__":
    # Build the interface and start the Gradio server only when this file is
    # executed as a script (not when it is imported).
    demo = create_ui()
    demo.launch()
|
colpali_manager.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from colpali_engine.models import ColPali
|
2 |
+
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
3 |
+
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
|
4 |
+
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
import torch
|
7 |
+
from typing import List, cast
|
8 |
+
|
9 |
+
from tqdm import tqdm
|
10 |
+
from PIL import Image
|
11 |
+
import os
|
12 |
+
|
13 |
+
import spaces
|
14 |
+
|
15 |
+
# Checkpoint identifier handed to both the model and the processor loaders.
model_name = "vidore/colpali-v1.2"
device = get_torch_device("cuda")

# Loaded once at import time. ColpaliManager's methods read these
# module-level `model` / `processor` objects instead of loading their own.
model = ColPali.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
).eval()

processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
|
25 |
+
|
26 |
+
class ColpaliManager:
    """Embeds page images and text queries with the module-level ColPali model."""

    def __init__(self, device="cuda", model_name="vidore/colpali-v1.2"):
        """Log the requested configuration.

        The arguments are only echoed here: the other methods use the
        module-level ``model`` and ``processor`` objects.
        """
        print(f"Initializing ColpaliManager with device {device} and model {model_name}")

    def _forward(self, batch):
        """Move one collated batch to the model's device and run it without gradients."""
        with torch.no_grad():
            on_device = {key: tensor.to(model.device) for key, tensor in batch.items()}
            return model(**on_device)

    @spaces.GPU
    def get_images(self, paths: list[str]) -> List[Image.Image]:
        """Open every path in ``paths`` as a PIL image, preserving order."""
        return list(map(Image.open, paths))

    @spaces.GPU
    def process_images(self, image_paths: list[str], batch_size=5):
        """Embed the images stored at ``image_paths``.

        Args:
            image_paths: Paths of the image files to embed.
            batch_size: Number of images per forward pass.

        Returns:
            A list with one float32 NumPy array per image, in input order.
        """
        print(f"Processing {len(image_paths)} image_paths")

        loader = DataLoader(
            dataset=ListDataset[str](self.get_images(image_paths)),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda items: processor.process_images(items),
        )

        page_embeddings: List[torch.Tensor] = []
        for batch in tqdm(loader):
            output = self._forward(batch)
            # Split the batch dimension so each image keeps its own tensor.
            page_embeddings.extend(torch.unbind(output.to(device)))

        return [emb.float().cpu().numpy() for emb in page_embeddings]

    @spaces.GPU
    def process_text(self, texts: list[str]):
        """Embed the query strings in ``texts``, one query per forward pass.

        Returns:
            A list with one float32 NumPy array per query, in input order.
        """
        print(f"Processing {len(texts)} texts")

        loader = DataLoader(
            dataset=ListDataset[str](texts),
            batch_size=1,
            shuffle=False,
            collate_fn=lambda items: processor.process_queries(items),
        )

        query_embeddings: List[torch.Tensor] = []
        for batch in loader:
            output = self._forward(batch)
            query_embeddings.extend(torch.unbind(output.to(device)))

        return [emb.float().cpu().numpy() for emb in query_embeddings]
|
95 |
+
|
96 |
+
|
97 |
+
|
middleware.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from colpali_manager import ColpaliManager
|
2 |
+
from milvus_manager import MilvusManager
|
3 |
+
from pdf_manager import PdfManager
|
4 |
+
import hashlib
|
5 |
+
|
6 |
+
|
7 |
+
# Module-level helpers created once at import and used by every Middleware
# instance (Middleware reads these globals directly).
pdf_manager = PdfManager()
colpali_manager = ColpaliManager()
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
class Middleware:
    """Connects PDF rendering, ColPali embedding and Milvus storage for one id."""

    def __init__(self, id: str, create_collection=True):
        """Open the Milvus database file derived from ``id``.

        The file is named ``milvus_<first 8 hex chars of md5(id)>.db`` and the
        collection is always called ``"colpali"``. ``create_collection`` is
        passed through to ``MilvusManager`` unchanged.
        """
        digest = hashlib.md5(id.encode()).hexdigest()
        db_file = f"milvus_{digest[:8]}.db"
        self.milvus_manager = MilvusManager(db_file, "colpali", create_collection)

    def index(self, pdf_path: str, id: str, max_pages: int, pages: list[int] = None):
        """Render a PDF to page images, embed them and store the vectors.

        Args:
            pdf_path: Path of the PDF to index.
            id: Identifier used for the page-image folder.
            max_pages: Maximum number of pages to render.
            pages: Accepted but not used by this method.

        Returns:
            The list of saved page-image paths.
        """
        print(f"Indexing {pdf_path}, id: {id}, max_pages: {max_pages}")

        image_paths = pdf_manager.save_images(id, pdf_path, max_pages)
        print(f"Saved {len(image_paths)} images")

        colbert_vecs = colpali_manager.process_images(image_paths)

        # Pair each page's embedding with the file it came from.
        images_data = []
        for position, path in enumerate(image_paths):
            images_data.append(
                {"colbert_vecs": colbert_vecs[position], "filepath": path}
            )

        print(f"Inserting {len(images_data)} images data to Milvus")
        self.milvus_manager.insert_images_data(images_data)
        print("Indexing completed")

        return image_paths

    def search(self, search_queries: list[str]):
        """Search the collection once per query string.

        Returns:
            A list with one entry per query; each entry is the value returned
            by ``MilvusManager.search`` for ``topk=1``.
        """
        print(f"Searching for {len(search_queries)} queries")

        per_query_hits = []
        for text in search_queries:
            print(f"Searching for query: {text}")
            embedded = colpali_manager.process_text([text])[0]
            hits = self.milvus_manager.search(embedded, topk=1)
            print(f"Search result: {hits} for query: {text}")
            per_query_hits.append(hits)

        return per_query_hits
|
56 |
+
|
milvus_manager.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pymilvus import MilvusClient, DataType
|
2 |
+
import numpy as np
|
3 |
+
import concurrent.futures
|
4 |
+
|
5 |
+
|
6 |
+
class MilvusManager:
    """Stores multi-vector page embeddings in Milvus and ranks pages by MaxSim.

    Every embedding vector of a page becomes its own row; rows of one page
    share a ``doc_id`` and are numbered by ``seq_id``.
    """

    def __init__(self, milvus_uri, collection_name, create_collection, dim=128):
        """Connect to Milvus and optionally (re)build the collection.

        Args:
            milvus_uri: URI passed to ``MilvusClient``.
            collection_name: Name of the collection to use.
            create_collection: When true, drop/recreate the collection and
                build its vector index.
            dim: Dimension of each stored vector.
        """
        self.client = MilvusClient(uri=milvus_uri)
        self.collection_name = collection_name
        if self.client.has_collection(collection_name=self.collection_name):
            self.client.load_collection(collection_name)
        self.dim = dim

        if create_collection:
            self.create_collection()
            self.create_index()

    def create_collection(self):
        """Drop any existing collection of this name and create a fresh one."""
        if self.client.has_collection(collection_name=self.collection_name):
            self.client.drop_collection(collection_name=self.collection_name)
        # NOTE(review): pymilvus documents this option as
        # `enable_dynamic_field` (singular) - confirm the plural is honoured.
        schema = self.client.create_schema(
            auto_id=True,
            enable_dynamic_fields=True,
        )
        schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)
        schema.add_field(
            field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim
        )
        schema.add_field(field_name="seq_id", datatype=DataType.INT16)
        schema.add_field(field_name="doc_id", datatype=DataType.INT64)
        schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)

        self.client.create_collection(
            collection_name=self.collection_name, schema=schema
        )

    def create_index(self):
        """Create the HNSW inner-product index on the ``vector`` field."""
        self.client.release_collection(collection_name=self.collection_name)
        # NOTE(review): this drops an index named "vector" while the index
        # created below is named "vector_index" - confirm which is intended.
        self.client.drop_index(
            collection_name=self.collection_name, index_name="vector"
        )
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="vector",
            index_name="vector_index",
            index_type="HNSW",
            metric_type="IP",
            params={
                "M": 16,
                "efConstruction": 500,
            },
        )

        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )

    def create_scalar_index(self):
        """Create an INVERTED index on the ``doc_id`` field."""
        self.client.release_collection(collection_name=self.collection_name)

        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="doc_id",
            index_name="int32_index",
            index_type="INVERTED",
        )

        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )

    def search(self, data, topk):
        """Return the ``topk`` best pages for a multi-vector query.

        A vector search first collects candidate ``doc_id`` values; each
        candidate is then rescored as the sum, over query vectors, of the
        maximum inner product with that page's stored vectors.

        Args:
            data: 2-D array of query vectors (one row per query vector).
            topk: Maximum number of results to return.

        Returns:
            A list of ``(score, doc_id)`` tuples sorted by descending score.
        """
        search_params = {"metric_type": "IP", "params": {}}
        results = self.client.search(
            self.collection_name,
            data,
            limit=int(50),
            output_fields=["vector", "seq_id", "doc_id"],
            search_params=search_params,
        )
        doc_ids = {hit["entity"]["doc_id"] for hits in results for hit in hits}
        if not doc_ids:
            return []

        def rerank_single_doc(doc_id, data, client, collection_name):
            # Score a page using only its own vectors. The previous filter
            # also matched doc_id + 1, which let the following page's vectors
            # inflate this page's score.
            # NOTE(review): limit=1000 caps how many of a page's vectors are
            # scored - confirm pages never store more than that.
            doc_colbert_vecs = client.query(
                collection_name=collection_name,
                filter=f"doc_id in [{doc_id}]",
                output_fields=["seq_id", "vector", "doc"],
                limit=1000,
            )
            doc_vecs = np.vstack([row["vector"] for row in doc_colbert_vecs])
            score = np.dot(data, doc_vecs.T).max(1).sum()
            return (score, doc_id)

        scores = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=300) as executor:
            futures = {
                executor.submit(
                    rerank_single_doc, doc_id, data, self.client, self.collection_name
                ): doc_id
                for doc_id in doc_ids
            }
            for future in concurrent.futures.as_completed(futures):
                scores.append(future.result())

        scores.sort(key=lambda x: x[0], reverse=True)
        # Slicing already copes with fewer than ``topk`` results.
        return scores[:topk]

    def insert(self, data):
        """Insert one page's vectors as individual rows.

        Args:
            data: Dict with ``"colbert_vecs"`` (the page's vectors),
                ``"doc_id"`` and ``"filepath"``. The file path is stored only
                on the first row (``seq_id`` 0); other rows get ``""``.
        """
        rows = [
            {
                "vector": vec,
                "seq_id": seq_id,
                "doc_id": data["doc_id"],
                "doc": data["filepath"] if seq_id == 0 else "",
            }
            for seq_id, vec in enumerate(data["colbert_vecs"])
        ]
        if not rows:
            # Nothing to store for a page without vectors.
            return

        self.client.insert(self.collection_name, rows)

    def get_images_as_doc(self, images_with_vectors: list):
        """Attach a ``doc_id`` (the list position) to each image record."""
        return [
            {
                "colbert_vecs": item["colbert_vecs"],
                "doc_id": position,
                "filepath": item["filepath"],
            }
            for position, item in enumerate(images_with_vectors)
        ]

    def insert_images_data(self, image_data):
        """Insert every image record, numbering pages by list position."""
        for record in self.get_images_as_doc(image_data):
            self.insert(record)
|
161 |
+
|
162 |
+
|
packages.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
poppler-utils
|
pdf_manager.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pdf2image import convert_from_path
|
2 |
+
import os
|
3 |
+
import shutil
|
4 |
+
|
5 |
+
class PdfManager:
    """Renders PDF pages to PNG files under ``pages/<id>/``."""

    def __init__(self):
        pass

    def clear_and_recreate_dir(self, output_folder):
        """Delete ``output_folder`` if it exists, then create it empty."""
        print(f"Clearing output folder {output_folder}")

        if os.path.exists(output_folder):
            shutil.rmtree(output_folder)

        os.makedirs(output_folder)

    def save_images(self, id, pdf_path, max_pages, pages: list[int] = None) -> list[str]:
        """Convert a PDF to page images and save them as PNG files.

        Args:
            id: Identifier used as the sub-folder name under ``pages/``.
            pdf_path: Path of the PDF to convert.
            max_pages: Maximum number of pages to save; a falsy value means
                no limit.
            pages: Optional zero-based page indices to keep; when empty or
                ``None`` every page is eligible.

        Returns:
            The paths of the files actually written, in page order. Files are
            named ``page_<n>.png`` with ``n`` being the one-based page number.
        """
        output_folder = f"pages/{id}/"
        images = convert_from_path(pdf_path)

        print(f"Saving images from {pdf_path} to {output_folder}. Max pages: {max_pages}")

        self.clear_and_recreate_dir(output_folder)

        # Record each written path instead of rebuilding page_1..page_n
        # afterwards: with a ``pages`` filter the saved page numbers are not
        # consecutive, so the reconstructed names pointed at missing files.
        saved_paths = []

        for i, image in enumerate(images):
            if max_pages and len(saved_paths) >= max_pages:
                break

            if pages and i not in pages:
                continue

            full_save_path = f"{output_folder}/page_{i + 1}.png"
            image.save(full_save_path, "PNG")
            saved_paths.append(full_save_path)

        return saved_paths
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio==4.25.0
|
2 |
+
PyMuPDF==1.24.9
|
3 |
+
pdf2image==1.17.0
|
4 |
+
pymilvus==2.4.9
|
5 |
+
colpali_engine==0.3.4
|
6 |
+
tqdm==4.66.5
|
7 |
+
pillow==10.4.0
|
8 |
+
spaces==0.30.4
|