Spaces:
Runtime error
Runtime error
File size: 4,239 Bytes
602d806 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import gradio as gr
from pdf2image import convert_from_path
import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor
from custom_colbert.models.paligemma_colbert_architecture import ColPali
from custom_colbert.trainer.retrieval_evaluator import CustomEvaluator
def process_images(processor, images, max_length: int = 50):
    """Tokenise a batch of page images for the ColPali document encoder.

    Every image is paired with the fixed prompt ``"Describe the image."`` and
    converted to RGB before the whole batch is handed to ``processor``.

    Args:
        processor: callable processor exposing an ``image_seq_length`` attribute.
        images: sized sequence of images offering a ``convert(mode)`` method.
        max_length: text-token allowance added on top of
            ``processor.image_seq_length``.

    Returns:
        Whatever ``processor`` returns for the batch, requested as PyTorch
        tensors (``return_tensors="pt"``).
    """
    prompts = len(images) * ["Describe the image."]
    rgb_pages = [page.convert("RGB") for page in images]
    # NOTE(review): with padding="longest" and no truncation, max_length
    # probably has no effect on the encoded batch — confirm against transformers.
    return processor(
        text=prompts,
        images=rgb_pages,
        return_tensors="pt",
        padding="longest",
        max_length=processor.image_seq_length + max_length,
    )
def process_queries(processor, queries, mock_image, max_length: int = 50):
    """Tokenise text queries for the ColPali query encoder.

    Each query is formatted as ``"Question: <query>"`` followed by five
    ``<unused0>`` tokens. The processor requires an image per text, so
    ``mock_image`` is supplied and then thrown away: ``pixel_values`` is
    deleted and the first ``processor.image_seq_length`` positions are cut
    from ``input_ids`` and ``attention_mask``.

    Args:
        processor: callable processor exposing ``image_seq_length``.
        queries: iterable of query strings.
        mock_image: placeholder image with a ``convert(mode)`` method.
        max_length: text-token allowance added on top of the image tokens.

    Returns:
        The processor output without ``pixel_values`` and with the leading
        image-length positions removed from the sequence tensors.
    """
    # NOTE(review): the five <unused0> tokens are presumably the query
    # augmentation padding the model was trained with — confirm.
    augmentation = "<unused0>" * 5
    prompts = [f"Question: {q}{augmentation}" for q in queries]
    n_image_tokens = processor.image_seq_length

    encoded = processor(
        # NOTE: the image is not used in the encoded query but it is required
        # for calling the processor; one converted copy is repeated per query.
        images=[mock_image.convert("RGB")] * len(prompts),
        text=prompts,
        return_tensors="pt",
        padding="longest",
        max_length=max_length + n_image_tokens,
    )

    del encoded["pixel_values"]
    # Cut the first image_seq_length positions (presumably the image-token
    # slots the processor prepends) so only the text part remains.
    for field in ("input_ids", "attention_mask"):
        encoded[field] = encoded[field][..., n_image_tokens:]
    return encoded
def search(query: str, ds, images) -> tuple:
    """Find the indexed page whose embedding best matches ``query``.

    Args:
        query: free-text question entered by the user.
        ds: list of per-page embeddings returned by ``index`` (held in a
            ``gr.State``; ``None`` until PDFs have been converted).
        images: page images returned by ``index``, aligned index-for-index
            with ``ds``.

    Returns:
        ``(message, image)`` for the Gradio ``[message2, output_img]`` outputs.
        When nothing has been indexed yet, ``image`` is ``None``.
    """
    # gr.State starts out empty, so guard against searching before indexing.
    if not ds or not images:
        return "Please upload and convert PDFs before searching.", None

    with torch.no_grad():
        batch_query = process_queries(processor, [query], mock_image)
        batch_query = {k: v.to(device) for k, v in batch_query.items()}
        embeddings_query = model(**batch_query)
    qs = list(torch.unbind(embeddings_query.to("cpu")))

    # run evaluation
    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    # NOTE(review): assumes scores is (n_queries, n_pages) — confirm with
    # CustomEvaluator; argmax over axis 1 then picks the best page per query.
    scores = retriever_evaluator.evaluate(qs, ds)
    # Pull out a plain int. The raw argmax is a 1-element array/tensor: it
    # renders as "tensor([k])" in the message, and numpy refuses it as an index.
    best_page = int(scores.argmax(axis=1)[0])
    return f"The most relevant page is {best_page}", images[best_page]
def index(file):
    """Convert uploaded PDFs to page images and embed every page with ColPali.

    Args:
        file: the uploaded PDFs from ``gr.File(file_count="multiple")``, each
            passed to ``convert_from_path``; falsy when nothing was uploaded.

    Returns:
        ``(status message, page embeddings, page images)``. The embeddings list
        is aligned index-for-index with the images list, so ``search`` can map
        the best-scoring embedding back to its page.
    """
    if not file:
        return "No PDF uploaded", [], []

    images = []
    for f in file:
        images.extend(convert_from_path(f))

    # run inference - docs
    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: process_images(processor, x),
    )
    # Must start empty. Placeholder entries would shift every page index
    # relative to `images` and feed non-tensor values to the evaluator.
    ds = []
    with torch.no_grad():
        for batch_doc in tqdm(dataloader):
            batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
            # One embedding per page, moved to CPU for storage in gr.State.
            ds.extend(torch.unbind(embeddings_doc.to("cpu")))
    return f"Uploaded and converted {len(images)} pages", ds, images
COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']

# Load model: base PaliGemma checkpoint plus the ColPali adapter from model_name.
model_name = "coldoc/colpali-3b-mix-448"
# Fall back to CPU when no GPU is present. A hard-coded device_map="cuda"
# makes start-up fail on CPU-only hardware.
model = ColPali.from_pretrained(
    "google/paligemma-3b-mix-448",
    torch_dtype=torch.bfloat16,
    device_map="cuda" if torch.cuda.is_available() else "cpu",
).eval()
model.load_adapter(model_name)
processor = AutoProcessor.from_pretrained(model_name)
# Inputs are moved to wherever the model actually landed.
device = model.device
# Blank white placeholder: process_queries must pass an image to the
# processor even though query batches discard it.
mock_image = Image.new("RGB", (448, 448), (255, 255, 255))
with gr.Blocks() as demo:
    gr.Markdown("# PDF to 🤗 Dataset")
    gr.Markdown("## 1️⃣ Upload PDFs")
    # Gradio's file_types takes extensions with a leading dot (e.g. ".json")
    # or a category like "image"; a bare "pdf" is not a valid extension filter.
    file = gr.File(file_types=[".pdf"], file_count="multiple")

    gr.Markdown("## 2️⃣ Convert the PDFs and upload")
    # NOTE(review): button emojis reconstructed from mis-encoded text — confirm.
    convert_button = gr.Button("🔄 Convert and upload")
    message = gr.Textbox("Files not yet uploaded")
    # Hold index() results (embeddings, page images) until search() needs them.
    embeds = gr.State()
    imgs = gr.State()

    # Define the actions
    convert_button.click(
        index,
        inputs=[file],
        outputs=[message, embeds, imgs],
    )

    gr.Markdown("## 3️⃣ Search")
    query = gr.Textbox(placeholder="Enter your query here")
    search_button = gr.Button("🔍 Search")
    message2 = gr.Textbox("Query not yet set")
    output_img = gr.Image()

    search_button.click(
        search,
        inputs=[query, embeds, imgs],
        outputs=[message2, output_img],
    )
if __name__ == "__main__":
    # Enable Gradio's request queue, capped at 10 pending events.
    demo.queue(max_size=10).launch(debug=True)