thomasht86's picture
Upload folder using huggingface_hub
8ce4d25 verified
#!/usr/bin/env python3
import os
import torch
from torch.utils.data import DataLoader
from PIL import Image
import numpy as np
from typing import cast
import asyncio
from colpali_engine.models import ColPali, ColPaliProcessor
from colpali_engine.utils.torch_utils import get_torch_device
from vespa.application import Vespa
from vespa.io import VespaQueryResponse
from dotenv import load_dotenv
from pathlib import Path
# Upper bound on how many query-token vectors are binarized for the
# nearestNeighbor query; tokens beyond this count are dropped there.
MAX_QUERY_TERMS = 64
# Directory for (optionally) saved result images, located next to this script.
# ``Path(__file__)`` is the script file itself, so ``.parent`` is required:
# a path nested under a regular file can never be created as a directory.
SAVEDIR = Path(__file__).parent / "output" / "images"
# Runs at import time, before main() reads VESPA_APP_URL and
# VESPA_CLOUD_SECRET_TOKEN from os.environ, so that values kept in a local
# .env file (python-dotenv) are available as environment variables.
load_dotenv()
def process_queries(processor, queries, image):
    """Run ``processor`` over a batch of text queries paired with one image.

    The same ``image`` object is repeated once per query so the image list
    and the text list handed to the processor have equal length.

    Args:
        processor: Callable accepting the keyword arguments ``images``,
            ``text``, ``return_tensors`` and ``padding``.
        queries: Sized sequence of query strings.
        image: Image reused for every query in the batch.

    Returns:
        Whatever ``processor`` returns when called with
        ``return_tensors="pt"`` and ``padding=True``.
    """
    batch_size = len(queries)
    repeated_images = [image] * batch_size
    return processor(
        images=repeated_images,
        text=queries,
        return_tensors="pt",
        padding=True,
    )
def display_query_results(query, response, hits=5):
    """Print a human-readable summary of a query response.

    Args:
        query: Query text, echoed in the summary header.
        response: Object exposing ``json`` (the decoded response body as a
            dict) and ``hits`` (a list of hit dicts).
        hits: Maximum number of hits to include in the summary.

    Raises:
        KeyError: If a listed hit lacks ``id``, ``relevance`` or one of the
            ``title`` / ``url`` / ``page_number`` fields.
    """
    body = response.json
    # -1 is the fallback when the response carries no timing section.
    query_time = round(body.get("timing", {}).get("searchtime", -1), 2)
    count = body.get("root", {}).get("fields", {}).get("totalCount", 0)

    # Collect the pieces and join once instead of growing a string with +=.
    parts = [
        f"Query text: '{query}', query time {query_time}s, count={count}, top results:\n"
    ]
    for i, hit in enumerate(response.hits[:hits]):
        fields = hit["fields"]
        title = fields["title"]
        url = fields["url"]
        page = fields["page_number"]
        _id = hit["id"]
        score = hit["relevance"]
        # The "image" field is deliberately not read here: it is not part of
        # the printed summary, so a hit without it must not fail.
        # To save the page image instead, decode fields["image"] (presumably
        # base64, per the original snippet) and write it under SAVEDIR.
        parts.append(f"\nPDF Result {i + 1}\n")
        # page_number is shown 1-based.
        parts.append(f"Title: {title}, page {page + 1} with score {score:.2f}\n")
        parts.append(f"URL: {url}\n")
        parts.append(f"ID: {_id}\n")
    print("".join(parts))
async def query_vespa_default(app, queries, qs):
    """Run every query against Vespa with the ``default`` rank profile.

    Args:
        app: Application object whose ``asyncio(...)`` returns an async
            context manager yielding a session with an awaitable ``query``.
        queries: Query strings, matched to ``qs`` by position.
        qs: Per-query embeddings; iterating ``qs[i]`` yields one vector per
            query token, each supporting ``.tolist()``.

    Raises:
        AssertionError: If a response reports that it was not successful.
    """
    yql = "select documentid,title,url,image,page_number from pdf_page where userInput(@userQuery)"
    async with app.asyncio(connections=1, total_timeout=120) as session:
        for position, query_text in enumerate(queries):
            token_vectors = qs[position]
            # Cells of the query(qt) tensor: token index -> list of floats.
            qt_cells = {
                token_idx: vector.tolist()
                for token_idx, vector in enumerate(token_vectors)
            }
            request_body = {
                "input.query(qt)": qt_cells,
                "presentation.timing": True,
            }
            response: VespaQueryResponse = await session.query(
                yql=yql,
                ranking="default",
                userQuery=query_text,
                timeout=120,
                hits=3,
                body=request_body,
            )
            assert response.is_successful()
            display_query_results(query_text, response)
async def query_vespa_nearest_neighbor(app, queries, qs):
    """Run every query using nearestNeighbor retrieval and the
    ``retrieval-and-rerank`` rank profile.

    For each query, every token vector is binarized (values > 0 become 1,
    everything else 0), bit-packed with ``np.packbits`` and cast to int8.
    The binary vectors are sent both as cells of ``query(qtb)`` and as
    individual ``query(rq{i})`` tensors, one per nearestNeighbor operator
    over the ``embedding`` field; the operators are OR-ed together. At most
    ``MAX_QUERY_TERMS`` token vectors are binarized, while the float tensor
    ``query(qt)`` always carries all tokens.

    Args:
        app: Application object whose ``asyncio(...)`` returns an async
            context manager yielding a session with an awaitable ``query``.
        queries: Query strings, matched to ``qs`` by position.
        qs: Per-query embeddings; iterating ``qs[i]`` yields one vector per
            query token, each supporting ``.tolist()``.

    Raises:
        AssertionError: If a response reports that it was not successful;
            the response JSON is attached as the assertion message.
    """
    # Hyperparameter that can be tuned for speed versus accuracy.
    target_hits_per_query_tensor = 20
    async with app.asyncio(connections=1, total_timeout=180) as session:
        for idx, query in enumerate(queries):
            float_query_embedding = {k: v.tolist() for k, v in enumerate(qs[idx])}

            binary_query_embeddings = {}
            for k, v in float_query_embedding.items():
                # Check *before* inserting so the warning only fires when a
                # token is really dropped; a query with exactly
                # MAX_QUERY_TERMS tokens is not truncated.
                if k >= MAX_QUERY_TERMS:
                    print(
                        f"Warning: Query has more than {MAX_QUERY_TERMS} terms. Truncating."
                    )
                    break
                binary_query_embeddings[k] = (
                    np.packbits(np.where(np.array(v) > 0, 1, 0))
                    .astype(np.int8)
                    .tolist()
                )

            # The mixed tensors used in MaxSim calculations: both the binary
            # and the float representation are sent.
            query_tensors = {
                "input.query(qtb)": binary_query_embeddings,
                "input.query(qt)": float_query_embedding,
            }
            # One rq{i} tensor and one nearestNeighbor operator per retained
            # token, built in a single pass.
            nn_clauses = []
            for i, binary_vector in binary_query_embeddings.items():
                query_tensors[f"input.query(rq{i})"] = binary_vector
                nn_clauses.append(
                    f"({{targetHits:{target_hits_per_query_tensor}}}nearestNeighbor(embedding,rq{i}))"
                )
            # The nearestNeighbor operators are combined with OR.
            where_clause = " OR ".join(nn_clauses)

            response: VespaQueryResponse = await session.query(
                body={
                    **query_tensors,
                    "presentation.timing": True,
                    "yql": f"select documentid, title, url, image, page_number from pdf_page where {where_clause}",
                    "ranking.profile": "retrieval-and-rerank",
                    "timeout": 120,
                    "hits": 3,
                },
            )
            assert response.is_successful(), response.json
            display_query_results(query, response)
def main():
    """Embed a fixed set of demo queries with ColPali and run them on Vespa.

    Reads ``VESPA_APP_URL`` and ``VESPA_CLOUD_SECRET_TOKEN`` from the
    environment, checks the application status, loads the model and
    processor, embeds three hard-coded queries (each paired with a blank
    white 448x448 image) and then queries Vespa twice: first with the default
    rank profile, then with nearestNeighbor retrieval.

    Returns early (after printing a message) when the status check does not
    answer with HTTP 200.

    Raises:
        ValueError: If either environment variable is unset or empty.
    """
    vespa_app_url = os.environ.get("VESPA_APP_URL")
    vespa_cloud_secret_token = os.environ.get("VESPA_CLOUD_SECRET_TOKEN")
    if not (vespa_app_url and vespa_cloud_secret_token):
        raise ValueError(
            "Please set the VESPA_APP_URL and VESPA_CLOUD_SECRET_TOKEN environment variables"
        )

    # Connect and bail out early if the application is not reachable.
    app = Vespa(url=vespa_app_url, vespa_cloud_secret_token=vespa_cloud_secret_token)
    if app.get_application_status().status_code != 200:
        print(f"Failed to connect to Vespa at {vespa_app_url}")
        return
    print(f"Connected to Vespa at {vespa_app_url}")

    # Model and processor setup.
    device = get_torch_device("auto")
    print(f"Using device: {device}")
    model_name = "vidore/colpali-v1.2"
    processor_name = "google/paligemma-3b-mix-448"
    if torch.cuda.is_available():
        dtype = torch.bfloat16
    else:
        dtype = torch.float32
    loaded_model = ColPali.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map=device,
    )
    model = cast(ColPali, loaded_model).eval()
    processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(processor_name))

    # The processor is called with an image per query, so a blank white
    # placeholder is supplied for these text-only queries.
    blank_page = Image.new("RGB", (448, 448), (255, 255, 255))
    queries = [
        "Percentage of non-fresh water as source?",
        "Policies related to nature risk?",
        "How much of produced water is recycled?",
    ]

    def collate(batch):
        return process_queries(processor, batch, blank_page)

    loader = DataLoader(
        queries,
        batch_size=1,
        shuffle=False,
        collate_fn=collate,
    )

    # One embedding per query, moved to CPU and collected in query order.
    qs = []
    for batch in loader:
        with torch.no_grad():
            on_device = {name: tensor.to(model.device) for name, tensor in batch.items()}
            batch_embeddings = model(**on_device)
            qs.extend(torch.unbind(batch_embeddings.to("cpu")))

    print("Performing queries using default rank profile:")
    asyncio.run(query_vespa_default(app, queries, qs))

    print("Performing queries using nearestNeighbor:")
    asyncio.run(query_vespa_nearest_neighbor(app, queries, qs))
# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()