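"""ColPali document retrieval demo.

A FastAPI backend converts uploaded PDFs into page images, embeds each page
with the ColPali vision-language retriever, and serves similarity search; a
Gradio front end drives it over HTTP. A sketch of one way to serve both from
a single process is at the bottom of the file (port 8082 matches the URLs
used in the Gradio callbacks).
"""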
import os
import sys
from typing import List

import gradio as gr
import requests
import torch
from fastapi import FastAPI, File, UploadFile
from pdf2image import convert_from_bytes
from PIL import Image
from torch.utils.data import DataLoader
from transformers import AutoProcessor

# Make the vendored colpali-engine package importable.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), './colpali-main')))

from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import (
    process_images,
    process_queries,
)

app = FastAPI()

# Load the base PaliGemma model, then apply the ColPali retrieval adapter.
model_name = "vidore/colpali"
token = os.environ.get("HF_TOKEN")
model = ColPali.from_pretrained(
    "google/paligemma-3b-mix-448", torch_dtype=torch.bfloat16, device_map="cpu", token=token
).eval()
model.load_adapter(model_name)
processor = AutoProcessor.from_pretrained(model_name, token=token)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
if str(model.device) != device:
    model.to(device)

# Blank image used as the visual placeholder when embedding text-only queries.
mock_image = Image.new("RGB", (448, 448), (255, 255, 255))

# In-memory storage for the indexed corpus: one multi-vector embedding and
# one page image per PDF page.
ds = []
images = []

@app.post("/index")
async def index(files: List[UploadFile] = File(...)):
    global ds, images
    images = []
    ds = []
    for file in files:
        content = await file.read()
        pdf_image_list = convert_from_path(io.BytesIO(content))
        images.extend(pdf_image_list)
    
    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: process_images(processor, x),
    )
    for batch_doc in dataloader:
        with torch.no_grad():
            batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
    
    return {"message": f"Uploaded and converted {len(images)} pages"}

@app.post("/search")
async def search(query: str, k: int):
    qs = []
    with torch.no_grad():
        batch_query = process_queries(processor, [query], mock_image)
        batch_query = {k: v.to(device) for k, v in batch_query.items()}
        embeddings_query = model(**batch_query)
        qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)

    top_k_indices = scores.argsort(axis=1)[0][-k:][::-1]

    results = [{"page": idx, "image": "image_placeholder"} for idx in top_k_indices]

    return {"results": results}

def index_gradio(file, ds):
    """Upload the selected PDFs to the backend and trigger indexing."""
    url = "http://localhost:8082/index"
    # Gradio passes temp-file objects; re-open each one by path for the upload.
    files = [
        ("files", (os.path.basename(f.name), open(f.name, "rb"), "application/pdf"))
        for f in file
    ]
    response = requests.post(url, files=files)
    result = response.json()
    return result["message"], ds, []

def search_gradio(query: str, ds, images_state, k):
    """Send a search query to the backend and resolve hits to page images."""
    url = "http://localhost:8082/search"
    # The endpoint declares scalar parameters, so FastAPI reads them from the
    # query string; send them as params rather than a JSON body.
    response = requests.post(url, params={"query": query, "k": int(k)})
    result = response.json()
    # The API returns page indices; map them back to the in-process page images
    # (the UI and API share this module's globals, see the bottom of the file).
    return [images[r["page"]] for r in result["results"]]

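# Gradio front end: upload PDFs, index them, then query the indexed pages.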
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models 📚")

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 1️⃣ Upload PDFs")
            file = gr.File(file_types=["pdf"], file_count="multiple", label="Upload PDFs")

            convert_button = gr.Button("🔄 Index documents")
            message = gr.Textbox("Files not yet uploaded", label="Status")
            embeds = gr.State(value=[])
            imgs = gr.State(value=[])

        with gr.Column(scale=3):
            gr.Markdown("## 2️⃣ Search")
            query = gr.Textbox(placeholder="Enter your query here", label="Query")
            k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=1)

    # Define the actions
    search_button = gr.Button("🔍 Search", variant="primary")
    output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)

    convert_button.click(index_gradio, inputs=[file, embeds], outputs=[message, embeds, imgs])
    search_button.click(search_gradio, inputs=[query, embeds, imgs, k], outputs=[output_gallery])

if __name__ == "__main__":
    import uvicorn

    # One way to run both pieces together: mount the Gradio UI on the FastAPI
    # app so they share this process (and hence the in-memory index), and serve
    # everything on port 8082, the port the UI callbacks target.
    app = gr.mount_gradio_app(app, demo.queue(max_size=10), path="/demo")
    uvicorn.run(app, host="0.0.0.0", port=8082)