File size: 7,477 Bytes
8ce4d25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
#!/usr/bin/env python3

import argparse
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from io import BytesIO
from typing import cast
import os
import json
import hashlib

from colpali_engine.models import ColPali, ColPaliProcessor
from colpali_engine.utils.torch_utils import get_torch_device
from vidore_benchmark.utils.image_utils import scale_image, get_base64_image
import requests
from pdf2image import convert_from_path
from pypdf import PdfReader
import numpy as np
from vespa.application import Vespa
from vespa.io import VespaResponse
from dotenv import load_dotenv

# Load variables from a .env file (python-dotenv) into the process environment
# before main() reads VESPA_APP_URL and VESPA_CLOUD_SECRET_TOKEN via os.getenv.
load_dotenv()


def main():
    """Build a ColPali page feed for the sample PDFs and feed it into Vespa.

    Workflow:
      1. Parse ``--application_name`` / ``--vespa_schema_name`` and read
         ``VESPA_APP_URL`` / ``VESPA_CLOUD_SECRET_TOKEN`` from the environment.
      2. If ``vespa_feed.json`` exists in the working directory, reuse it.
         Otherwise download the sample PDFs, embed every page with ColPali,
         binarize the patch embeddings and save the feed to that file.
      3. Feed every page document to the Vespa application, printing any
         per-document failure.

    Raises:
        SystemExit: if ``VESPA_APP_URL`` is not set.
        Exception: if a PDF download answers with a non-200 status code.
        ValueError: if the number of rendered page images differs from the
            number of pages text was extracted from.
    """
    # Local import: only the PDF helper below needs it, and keeping it here
    # leaves the file's top-level import block untouched.
    import tempfile

    feed_path = "vespa_feed.json"
    model_name = "vidore/colpali-v1.2"

    parser = argparse.ArgumentParser(description="Feed data into Vespa application")
    # The options are intentionally not ``required``: combining ``required=True``
    # with a default made the defaults unreachable. Passing them explicitly
    # still works exactly as before.
    parser.add_argument(
        "--application_name",
        default="colpalidemo",
        help="Vespa application name",
    )
    parser.add_argument(
        "--vespa_schema_name",
        default="pdf_page",
        help="Vespa schema name",
    )
    args = parser.parse_args()

    vespa_app_url = os.getenv("VESPA_APP_URL")
    vespa_cloud_secret_token = os.getenv("VESPA_CLOUD_SECRET_TOKEN")
    if not vespa_app_url:
        # Fail early with a readable message instead of handing None to Vespa.
        raise SystemExit(
            "VESPA_APP_URL is not set; export it or add it to the .env file"
        )

    # Set application and schema names
    application_name = args.application_name
    schema_name = args.vespa_schema_name
    # Instantiate Vespa connection using token; checking the status first
    # surfaces connection problems before any expensive work is done.
    app = Vespa(url=vespa_app_url, vespa_cloud_secret_token=vespa_cloud_secret_token)
    app.get_application_status()

    def download_pdf(url):
        """Download ``url`` and return its body wrapped in a ``BytesIO``.

        Raises ``Exception`` when the response status code is not 200.
        """
        # A timeout keeps a stalled server from hanging the script forever.
        response = requests.get(url, timeout=60)
        if response.status_code != 200:
            raise Exception(
                f"Failed to download PDF: Status code {response.status_code}"
            )
        return BytesIO(response.content)

    def get_pdf_images(pdf_url):
        """Return ``(images, page_texts)`` for the PDF at ``pdf_url``.

        ``images`` is what ``convert_from_path`` returns for the file and
        ``page_texts`` holds ``extract_text()`` of every page, in page order.
        """
        pdf_file = download_pdf(pdf_url)
        # pdf2image needs a path on disk. A private temporary directory avoids
        # clobbering a ``temp.pdf`` in the working directory and is removed
        # again once both the text and the images have been extracted.
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_file = os.path.join(tmp_dir, "temp.pdf")
            with open(temp_file, "wb") as f:
                f.write(pdf_file.read())
            reader = PdfReader(temp_file)
            page_texts = [page.extract_text() for page in reader.pages]
            images = convert_from_path(temp_file)
        # Explicit check instead of ``assert`` so it also runs under ``-O``.
        if len(images) != len(page_texts):
            raise ValueError(
                f"Page count mismatch for {pdf_url}: "
                f"{len(images)} images vs {len(page_texts)} page texts"
            )
        return (images, page_texts)

    # Define sample PDFs
    sample_pdfs = [
        {
            "title": "ConocoPhillips Sustainability Highlights - Nature (24-0976)",
            "url": "https://static.conocophillips.com/files/resources/24-0976-sustainability-highlights_nature.pdf",
        },
        {
            "title": "ConocoPhillips Managing Climate Related Risks",
            "url": "https://static.conocophillips.com/files/resources/conocophillips-2023-managing-climate-related-risks.pdf",
        },
        {
            "title": "ConocoPhillips 2023 Sustainability Report",
            "url": "https://static.conocophillips.com/files/resources/conocophillips-2023-sustainability-report.pdf",
        },
    ]

    if os.path.exists(feed_path):
        # Cached feed: no model is needed, so none is loaded on this path.
        print("Loading vespa_feed from vespa_feed.json")
        with open(feed_path, "r", encoding="utf-8") as f:
            vespa_feed_saved = json.load(f)
        vespa_feed = []
        for doc in vespa_feed_saved:
            # Put id format: 'id:application_name:schema_name::document_id'
            parts = doc["put"].split("::")
            document_id = parts[1] if len(parts) > 1 else ""
            vespa_feed.append({"id": document_id, "fields": doc["fields"]})
    else:
        print("Generating vespa_feed")

        device = get_torch_device("auto")
        print(f"Using device: {device}")

        # Load the model (bfloat16 when CUDA is available, float32 otherwise)
        model = cast(
            ColPali,
            ColPali.from_pretrained(
                model_name,
                torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
                device_map=device,
            ),
        ).eval()

        # Load the processor
        processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))

        # Download every PDF and render/extract its pages
        for pdf in sample_pdfs:
            page_images, page_texts = get_pdf_images(pdf["url"])
            pdf["images"] = page_images
            pdf["texts"] = page_texts

        # Generate embeddings, two pages per batch, in page order
        for pdf in sample_pdfs:
            page_embeddings = []
            dataloader = DataLoader(
                pdf["images"],
                batch_size=2,
                shuffle=False,
                collate_fn=processor.process_images,
            )
            for batch_doc in tqdm(dataloader):
                with torch.no_grad():
                    batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
                    embeddings_doc = model(**batch_doc)
                    # Split the batch back into one tensor per page, on CPU.
                    page_embeddings.extend(torch.unbind(embeddings_doc.to("cpu")))
            pdf["embeddings"] = page_embeddings

        # Prepare Vespa feed: one document per PDF page
        vespa_feed = []
        for pdf in sample_pdfs:
            url = pdf["url"]
            title = pdf["title"]
            for page_number, (page_text, embedding, image) in enumerate(
                zip(pdf["texts"], pdf["embeddings"], pdf["images"])
            ):
                base_64_image = get_base64_image(
                    scale_image(image, 640), add_url_prefix=False
                )
                # Binarize each patch embedding (positive -> 1, else 0), pack
                # the bits into bytes and store them hex-encoded, keyed by the
                # patch index.
                embedding_dict = {
                    idx: (
                        np.packbits(np.where(patch_embedding > 0, 1, 0))
                        .astype(np.int8)
                        .tobytes()
                        .hex()
                    )
                    for idx, patch_embedding in enumerate(embedding)
                }
                # Document id: md5 of url and page_number (an identifier, not
                # a security measure).
                id_hash = hashlib.md5(f"{url}_{page_number}".encode()).hexdigest()
                vespa_feed.append(
                    {
                        "id": id_hash,
                        "fields": {
                            "id": id_hash,
                            "url": url,
                            "title": title,
                            "page_number": page_number,
                            "image": base_64_image,
                            "text": page_text,
                            "embedding": embedding_dict,
                        },
                    }
                )

        # Save the feed as a list of {"put": ..., "fields": ...} operations.
        vespa_feed_to_save = [
            {
                "put": f"id:{application_name}:{schema_name}::{page['id']}",
                "fields": page["fields"],
            }
            for page in vespa_feed
        ]
        # Write to a sibling file first and rename it into place, so an
        # interrupted dump never leaves a truncated vespa_feed.json that a
        # later run would pick up as a valid cache.
        partial_path = feed_path + ".tmp"
        with open(partial_path, "w", encoding="utf-8") as f:
            json.dump(vespa_feed_to_save, f)
        os.replace(partial_path, feed_path)

    def callback(response: VespaResponse, doc_id: str):
        """Print a diagnostic line for every document Vespa failed to accept."""
        if not response.is_successful():
            print(
                f"Failed to feed document {doc_id} with status code {response.status_code}: Reason {response.get_json()}"
            )

    # Feed data into Vespa
    app.feed_iterable(vespa_feed, schema=schema_name, callback=callback)


# Run the feed pipeline only when this file is executed as a script, so that
# importing the module does not trigger it.
if __name__ == "__main__":
    main()