Spaces:
Running
on
T4
Running
on
T4
File size: 7,477 Bytes
8ce4d25 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 |
#!/usr/bin/env python3
import argparse
import hashlib
import json
import os
from io import BytesIO
from typing import cast

import numpy as np
import requests
import torch
from colpali_engine.models import ColPali, ColPaliProcessor
from colpali_engine.utils.torch_utils import get_torch_device
from dotenv import load_dotenv
from pdf2image import convert_from_bytes, convert_from_path
from pypdf import PdfReader
from torch.utils.data import DataLoader
from tqdm import tqdm
from vespa.application import Vespa
from vespa.io import VespaResponse
from vidore_benchmark.utils.image_utils import scale_image, get_base64_image
# Load variables from a local .env file into the process environment
# (python-dotenv) at import time, so main() can read VESPA_APP_URL and
# VESPA_CLOUD_SECRET_TOKEN via os.getenv.
load_dotenv()
# PDFs embedded and fed when no cached vespa_feed.json exists.
_SAMPLE_PDFS = [
    {
        "title": "ConocoPhillips Sustainability Highlights - Nature (24-0976)",
        "url": "https://static.conocophillips.com/files/resources/24-0976-sustainability-highlights_nature.pdf",
    },
    {
        "title": "ConocoPhillips Managing Climate Related Risks",
        "url": "https://static.conocophillips.com/files/resources/conocophillips-2023-managing-climate-related-risks.pdf",
    },
    {
        "title": "ConocoPhillips 2023 Sustainability Report",
        "url": "https://static.conocophillips.com/files/resources/conocophillips-2023-sustainability-report.pdf",
    },
]


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of the feeder script."""
    parser = argparse.ArgumentParser(description="Feed data into Vespa application")
    # Both options have defaults, so they must be optional: with
    # required=True the defaults could never take effect.
    parser.add_argument(
        "--application_name",
        default="colpalidemo",
        help="Vespa application name",
    )
    parser.add_argument(
        "--vespa_schema_name",
        default="pdf_page",
        help="Vespa schema name",
    )
    return parser.parse_args()


def _download_pdf(url: str, timeout: float = 60.0) -> BytesIO:
    """Download ``url`` and return the response body as an in-memory buffer.

    Args:
        url: Address of the PDF.
        timeout: Seconds passed to ``requests.get`` so a stalled server
            cannot hang the script indefinitely.

    Raises:
        Exception: if the server responds with a status code other than 200.
        requests.RequestException: on network failure or timeout.
    """
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        raise Exception(f"Failed to download PDF: Status code {response.status_code}")
    return BytesIO(response.content)


def _get_pdf_images(pdf_url: str) -> tuple[list, list[str]]:
    """Download a PDF and return ``(page_images, page_texts)``, one entry per page.

    The PDF is processed entirely in memory, so no scratch file is written to
    (or left behind in) the working directory.

    Raises:
        ValueError: if the rendered page count differs from the text page count.
    """
    pdf_bytes = _download_pdf(pdf_url).getvalue()
    reader = PdfReader(BytesIO(pdf_bytes))
    page_texts = [page.extract_text() for page in reader.pages]
    images = convert_from_bytes(pdf_bytes)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(images) != len(page_texts):
        raise ValueError(
            f"Page count mismatch for {pdf_url}: "
            f"{len(images)} rendered images vs {len(page_texts)} text pages"
        )
    return images, page_texts


def _load_feed(path: str) -> list[dict]:
    """Read a saved feed file and convert it to ``feed_iterable`` entries.

    Each saved entry looks like ``{"put": "id:<app>:<schema>::<doc_id>", "fields": {...}}``
    and becomes ``{"id": <doc_id>, "fields": {...}}``.
    """
    with open(path, "r", encoding="utf-8") as f:
        saved = json.load(f)
    feed = []
    for doc in saved:
        # Everything after the first "::" is the user-specified document id
        # ("" if the separator is missing).
        _, _, document_id = doc["put"].partition("::")
        feed.append({"id": document_id, "fields": doc["fields"]})
    return feed


def _load_model() -> tuple:
    """Load the ColPali model (in eval mode) and its processor; return ``(model, processor)``."""
    model_name = "vidore/colpali-v1.2"
    device = get_torch_device("auto")
    print(f"Using device: {device}")
    model = cast(
        ColPali,
        ColPali.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            device_map=device,
        ),
    ).eval()
    processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
    return model, processor


def _embed_images(model, processor, images: list, batch_size: int = 2) -> list[torch.Tensor]:
    """Run ``model`` over ``images`` in batches and return one CPU tensor per image."""
    dataloader = DataLoader(
        images,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=processor.process_images,
    )
    page_embeddings: list[torch.Tensor] = []
    with torch.no_grad():
        for batch_doc in tqdm(dataloader):
            batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
            # Split the batch along dim 0 into one tensor per page.
            page_embeddings.extend(torch.unbind(embeddings_doc.to("cpu")))
    return page_embeddings


def _binarize_embedding(embedding: torch.Tensor) -> dict[int, str]:
    """Binarize a per-page embedding into ``{patch_index: hex_string}``.

    Each row (patch vector) is thresholded at 0, bit-packed 8 values per byte
    and hex-encoded. Assumes ``embedding`` is 2-D (patches x dim), as implied by
    iterating it patch by patch — TODO confirm against the model output.
    """
    bits = (embedding > 0).numpy()
    # The int8 cast does not change the bytes; presumably it matches an int8
    # tensor field in the Vespa schema — confirm.
    packed = np.packbits(bits, axis=-1).astype(np.int8)
    return {idx: row.tobytes().hex() for idx, row in enumerate(packed)}


def _generate_feed(sample_pdfs: list[dict]) -> list[dict]:
    """Download, render and embed each PDF; return one feed entry per page.

    Each entry is ``{"id": id_hash, "fields": {...}}``. Here ``id_hash`` is the
    MD5 of ``"<url>_<page_number>"``. The fields hold the URL, title, page
    number, a base64 page image scaled via ``scale_image(image, 640)``, the
    extracted text and the binarized embedding.
    """
    model, processor = _load_model()
    feed = []
    for pdf in sample_pdfs:
        url = pdf["url"]
        title = pdf["title"]
        images, texts = _get_pdf_images(url)
        embeddings = _embed_images(model, processor, images)
        for page_number, (page_text, embedding, image) in enumerate(
            zip(texts, embeddings, images)
        ):
            base_64_image = get_base64_image(scale_image(image, 640), add_url_prefix=False)
            # Deterministic per (url, page), so repeated runs yield identical ids.
            id_hash = hashlib.md5(f"{url}_{page_number}".encode()).hexdigest()
            feed.append(
                {
                    "id": id_hash,
                    "fields": {
                        "id": id_hash,
                        "url": url,
                        "title": title,
                        "page_number": page_number,
                        "image": base_64_image,
                        "text": page_text,
                        "embedding": _binarize_embedding(embedding),
                    },
                }
            )
    return feed


def _save_feed(feed: list[dict], path: str, application_name: str, schema_name: str) -> None:
    """Write ``feed`` to ``path`` as ``[{"put": "id:<app>:<schema>::<id>", "fields": ...}]``."""
    to_save = [
        {
            "put": f"id:{application_name}:{schema_name}::{page['id']}",
            "fields": page["fields"],
        }
        for page in feed
    ]
    with open(path, "w", encoding="utf-8") as f:
        json.dump(to_save, f)


def _feed_callback(response: "VespaResponse", id: str) -> None:
    """Print a diagnostic for every document Vespa did not accept."""
    if not response.is_successful():
        print(
            f"Failed to feed document {id} with status code {response.status_code}: Reason {response.get_json()}"
        )


def main():
    """Build (or reload) the ColPali page feed and push it into Vespa.

    Reads ``VESPA_APP_URL`` (required) and ``VESPA_CLOUD_SECRET_TOKEN`` from
    the environment.

    If ``vespa_feed.json`` exists in the working directory, it is reused and
    the ColPali model is never loaded. Otherwise, the sample PDFs are
    downloaded and embedded, and the feed is saved to ``vespa_feed.json``
    before being fed.

    Raises:
        SystemExit: if ``VESPA_APP_URL`` is unset or empty.
    """
    args = _parse_args()
    application_name = args.application_name
    schema_name = args.vespa_schema_name

    vespa_app_url = os.getenv("VESPA_APP_URL")
    if not vespa_app_url:
        # Fail fast with a clear message instead of passing url=None to Vespa().
        raise SystemExit("VESPA_APP_URL environment variable is not set")
    vespa_cloud_secret_token = os.getenv("VESPA_CLOUD_SECRET_TOKEN")

    # Instantiate Vespa connection using token
    app = Vespa(url=vespa_app_url, vespa_cloud_secret_token=vespa_cloud_secret_token)
    app.get_application_status()

    feed_path = "vespa_feed.json"
    if os.path.exists(feed_path):
        print(f"Loading vespa_feed from {feed_path}")
        vespa_feed = _load_feed(feed_path)
    else:
        print("Generating vespa_feed")
        vespa_feed = _generate_feed(_SAMPLE_PDFS)
        _save_feed(vespa_feed, feed_path, application_name, schema_name)

    app.feed_iterable(vespa_feed, schema=schema_name, callback=_feed_callback)
# Run the feeder only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|