thomasht86 committed on
Commit
a0b3781
·
verified ·
1 Parent(s): 295263a

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. README.md +0 -8
  2. backend/colpali.py +2 -288
  3. backend/vespa_app.py +326 -17
  4. main.py +5 -11
README.md CHANGED
@@ -120,14 +120,6 @@ To feed the data, run:
120
  python feed_vespa.py --vespa_app_url https://myapp.z.vespa-app.cloud --vespa_cloud_secret_token mysecrettoken
121
  ```
122
 
123
- ### Connecting to the Vespa app and querying
124
-
125
- As a first step, you can run the `query_vespa.py` script to run some sample queries against the Vespa app:
126
-
127
- ```bash
128
- python query_vespa.py
129
- ```
130
-
131
  ### Starting the front-end
132
 
133
  ```bash
 
120
  python feed_vespa.py --vespa_app_url https://myapp.z.vespa-app.cloud --vespa_cloud_secret_token mysecrettoken
121
  ```
122
 
 
 
 
 
 
 
 
 
123
  ### Starting the front-end
124
 
125
  ```bash
backend/colpali.py CHANGED
@@ -13,7 +13,6 @@ import matplotlib.cm as cm
13
  import re
14
  import io
15
 
16
- import json
17
  import time
18
  import backend.testquery as testquery
19
 
@@ -24,12 +23,10 @@ from vidore_benchmark.interpretability.torch_utils import (
24
  normalize_similarity_map_per_query_token,
25
  )
26
  from vidore_benchmark.interpretability.vit_configs import VIT_CONFIG
27
- from vespa.application import Vespa
28
- from vespa.io import VespaQueryResponse
29
 
30
  matplotlib.use("Agg")
31
-
32
- MAX_QUERY_TERMS = 64
33
 
34
  COLPALI_GEMMA_MODEL_NAME = "vidore/colpaligemma-3b-pt-448-base"
35
 
@@ -62,54 +59,6 @@ def load_vit_config(model):
62
  return vit_config
63
 
64
 
65
- def save_figure(fig, filename: str = "similarity_map.png"):
66
- try:
67
- OUTPUT_DIR = Path(__file__).parent.parent / "output" / "sim_maps"
68
- OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
69
- fig.savefig(
70
- OUTPUT_DIR / filename,
71
- bbox_inches="tight",
72
- pad_inches=0,
73
- )
74
- except Exception as e:
75
- print(f"Failed to save figure: {e}")
76
-
77
-
78
- def annotate_plot(ax, query, selected_token):
79
- # Add the query text as a title over the image with opacity
80
- ax.text(
81
- 0.5,
82
- 0.95, # Adjust the position to be on the image (y=0.1 is 10% from the bottom)
83
- query,
84
- fontsize=18,
85
- color="white",
86
- ha="center",
87
- va="center",
88
- alpha=0.8, # Set opacity (1 is fully opaque, 0 is fully transparent)
89
- bbox=dict(
90
- boxstyle="round,pad=0.5", fc="black", ec="none", lw=0, alpha=0.5
91
- ), # Add a semi-transparent background
92
- transform=ax.transAxes, # Ensure the coordinates are relative to the axes
93
- )
94
-
95
- # Add annotation with the selected token over the image with opacity
96
- ax.text(
97
- 0.5,
98
- 0.05, # Position towards the top of the image
99
- f"Selected token: `{selected_token}`",
100
- fontsize=18,
101
- color="white",
102
- ha="center",
103
- va="center",
104
- alpha=0.8, # Set opacity for the text
105
- bbox=dict(
106
- boxstyle="round,pad=0.3", fc="black", ec="none", lw=0, alpha=0.5
107
- ), # Semi-transparent background
108
- transform=ax.transAxes, # Keep the coordinates relative to the axes
109
- )
110
- return ax
111
-
112
-
113
  def gen_similarity_maps(
114
  model: ColPali,
115
  processor: ColPaliProcessor,
@@ -140,11 +89,6 @@ def gen_similarity_maps(
140
 
141
  """
142
 
143
- start = time.perf_counter()
144
-
145
- # Prepare the colormap once to avoid recomputation
146
- colormap = cm.get_cmap("viridis")
147
-
148
  # Process images and store original images and sizes
149
  processed_images = []
150
  original_images = []
@@ -336,154 +280,6 @@ def get_query_embeddings_and_token_map(
336
  return q_emb, token_to_idx
337
 
338
 
339
- def format_query_results(query, response, hits=5) -> dict:
340
- query_time = response.json.get("timing", {}).get("searchtime", -1)
341
- query_time = round(query_time, 2)
342
- count = response.json.get("root", {}).get("fields", {}).get("totalCount", 0)
343
- result_text = f"Query text: '{query}', query time {query_time}s, count={count}, top results:\n"
344
- print(result_text)
345
- return response.json
346
-
347
-
348
- async def query_vespa_default(
349
- app: Vespa,
350
- query: str,
351
- q_emb: torch.Tensor,
352
- hits: int = 3,
353
- timeout: str = "10s",
354
- **kwargs,
355
- ) -> dict:
356
- async with app.asyncio(connections=1, total_timeout=120) as session:
357
- query_embedding = format_q_embs(q_emb)
358
-
359
- start = time.perf_counter()
360
- response: VespaQueryResponse = await session.query(
361
- body={
362
- "yql": "select id,title,url,blur_image,page_number,snippet,text,summaryfeatures from pdf_page where userQuery();",
363
- "ranking": "default",
364
- "query": query,
365
- "timeout": timeout,
366
- "hits": hits,
367
- "input.query(qt)": query_embedding,
368
- "presentation.timing": True,
369
- **kwargs,
370
- },
371
- )
372
- assert response.is_successful(), response.json
373
- stop = time.perf_counter()
374
- print(
375
- f"Query time + data transfer took: {stop - start} s, vespa said searchtime was {response.json.get('timing', {}).get('searchtime', -1)} s"
376
- )
377
- open("response.json", "w").write(json.dumps(response.json))
378
- return format_query_results(query, response)
379
-
380
-
381
- async def query_vespa_bm25(
382
- app: Vespa,
383
- query: str,
384
- q_emb: torch.Tensor,
385
- hits: int = 3,
386
- timeout: str = "10s",
387
- **kwargs,
388
- ) -> dict:
389
- async with app.asyncio(connections=1, total_timeout=120) as session:
390
- query_embedding = format_q_embs(q_emb)
391
-
392
- start = time.perf_counter()
393
- response: VespaQueryResponse = await session.query(
394
- body={
395
- "yql": "select id,title,url,blur_image,page_number,snippet,text,summaryfeatures from pdf_page where userQuery();",
396
- "ranking": "bm25",
397
- "query": query,
398
- "timeout": timeout,
399
- "hits": hits,
400
- "input.query(qt)": query_embedding,
401
- "presentation.timing": True,
402
- **kwargs,
403
- },
404
- )
405
- assert response.is_successful(), response.json
406
- stop = time.perf_counter()
407
- print(
408
- f"Query time + data transfer took: {stop - start} s, vespa said searchtime was {response.json.get('timing', {}).get('searchtime', -1)} s"
409
- )
410
- return format_query_results(query, response)
411
-
412
-
413
- def float_to_binary_embedding(float_query_embedding: dict) -> dict:
414
- binary_query_embeddings = {}
415
- for k, v in float_query_embedding.items():
416
- binary_vector = (
417
- np.packbits(np.where(np.array(v) > 0, 1, 0)).astype(np.int8).tolist()
418
- )
419
- binary_query_embeddings[k] = binary_vector
420
- if len(binary_query_embeddings) >= MAX_QUERY_TERMS:
421
- print(f"Warning: Query has more than {MAX_QUERY_TERMS} terms. Truncating.")
422
- break
423
- return binary_query_embeddings
424
-
425
-
426
- def create_nn_query_strings(
427
- binary_query_embeddings: dict, target_hits_per_query_tensor: int = 20
428
- ) -> Tuple[str, dict]:
429
- # Query tensors for nearest neighbor calculations
430
- nn_query_dict = {}
431
- for i in range(len(binary_query_embeddings)):
432
- nn_query_dict[f"input.query(rq{i})"] = binary_query_embeddings[i]
433
- nn = " OR ".join(
434
- [
435
- f"({{targetHits:{target_hits_per_query_tensor}}}nearestNeighbor(embedding,rq{i}))"
436
- for i in range(len(binary_query_embeddings))
437
- ]
438
- )
439
- return nn, nn_query_dict
440
-
441
-
442
- def format_q_embs(q_embs: torch.Tensor) -> dict:
443
- float_query_embedding = {k: v.tolist() for k, v in enumerate(q_embs)}
444
- return float_query_embedding
445
-
446
-
447
- async def query_vespa_nearest_neighbor(
448
- app: Vespa,
449
- query: str,
450
- q_emb: torch.Tensor,
451
- target_hits_per_query_tensor: int = 20,
452
- hits: int = 3,
453
- timeout: str = "10s",
454
- **kwargs,
455
- ) -> dict:
456
- # Hyperparameter for speed vs. accuracy
457
- async with app.asyncio(connections=1, total_timeout=180) as session:
458
- float_query_embedding = format_q_embs(q_emb)
459
- binary_query_embeddings = float_to_binary_embedding(float_query_embedding)
460
-
461
- # Mixed tensors for MaxSim calculations
462
- query_tensors = {
463
- "input.query(qtb)": binary_query_embeddings,
464
- "input.query(qt)": float_query_embedding,
465
- }
466
- nn_string, nn_query_dict = create_nn_query_strings(
467
- binary_query_embeddings, target_hits_per_query_tensor
468
- )
469
- query_tensors.update(nn_query_dict)
470
- response: VespaQueryResponse = await session.query(
471
- body={
472
- **query_tensors,
473
- "presentation.timing": True,
474
- # if we use rank({nn_string}, userQuery()), dynamic summary doesn't work, see https://github.com/vespa-engine/vespa/issues/28704
475
- "yql": f"select id,title,snippet,text,url,blur_image,page_number,summaryfeatures from pdf_page where {nn_string} or userQuery()",
476
- "ranking.profile": "retrieval-and-rerank",
477
- "timeout": timeout,
478
- "hits": hits,
479
- "query": query,
480
- **kwargs,
481
- },
482
- )
483
- assert response.is_successful(), response.json
484
- return format_query_results(query, response)
485
-
486
-
487
  def is_special_token(token: str) -> bool:
488
  # Pattern for tokens that start with '<', numbers, whitespace, or single characters, or the string 'Question'
489
  # Will exclude these tokens from the similarity map generation
@@ -492,55 +288,6 @@ def is_special_token(token: str) -> bool:
492
  return True
493
  return False
494
 
495
- async def get_full_image_from_vespa(
496
- app: Vespa,
497
- id: str) -> str:
498
- async with app.asyncio(connections=1, total_timeout=120) as session:
499
- start = time.perf_counter()
500
- response: VespaQueryResponse = await session.query(
501
- body={
502
- "yql": f"select full_image from pdf_page where id contains \"{id}\"",
503
- "ranking": "unranked",
504
- "presentation.timing": True,
505
- },
506
- )
507
- assert response.is_successful(), response.json
508
- stop = time.perf_counter()
509
- print(
510
- f"Getting image from Vespa took: {stop - start} s, vespa said searchtime was {response.json.get('timing', {}).get('searchtime', -1)} s"
511
- )
512
- return response.json["root"]["children"][0]["fields"]["full_image"]
513
-
514
- async def get_result_from_query(
515
- app: Vespa,
516
- processor: ColPaliProcessor,
517
- model: ColPali,
518
- query: str,
519
- q_embs: torch.Tensor,
520
- token_to_idx: Dict[str, int],
521
- ranking: str,
522
- ) -> Dict[str, Any]:
523
- # Get the query embeddings and token map
524
- print(query)
525
-
526
- print(token_to_idx)
527
- if ranking == "nn+colpali":
528
- result = await query_vespa_nearest_neighbor(app, query, q_embs)
529
- elif ranking == "bm25+colpali":
530
- result = await query_vespa_default(app, query, q_embs)
531
- elif ranking == "bm25":
532
- result = await query_vespa_bm25(app, query, q_embs)
533
- else:
534
- raise ValueError(f"Unsupported ranking: {ranking}")
535
- # Print score, title id, and text of the results
536
- for idx, child in enumerate(result["root"]["children"]):
537
- print(
538
- f"Result {idx+1}: {child['relevance']}, {child['fields']['title']}, {child['fields']['id']}"
539
- )
540
- for single_result in result["root"]["children"]:
541
- print(single_result["fields"].keys())
542
- return result
543
-
544
 
545
  def add_sim_maps_to_result(
546
  result: Dict[str, Any],
@@ -582,36 +329,3 @@ def add_sim_maps_to_result(
582
  # for token, sim_mapb64 in sim_map_dict.items():
583
  # single_result["fields"][f"sim_map_{token}"] = sim_mapb64
584
  return result
585
-
586
-
587
- if __name__ == "__main__":
588
- model, processor = load_model()
589
- vit_config = load_vit_config(model)
590
- query = "How many percent of source water is fresh water?"
591
- image_filepath = (
592
- Path(__file__).parent.parent
593
- / "static"
594
- / "assets"
595
- / "ConocoPhillips Sustainability Highlights - Nature (24-0976).png"
596
- )
597
- q_embs, token_to_idx = get_query_embeddings_and_token_map(
598
- processor,
599
- model,
600
- query,
601
- )
602
- figs_images = gen_similarity_maps(
603
- model,
604
- processor,
605
- model.device,
606
- vit_config,
607
- query=query,
608
- query_embs=q_embs,
609
- token_idx_map=token_to_idx,
610
- images=[image_filepath],
611
- vespa_sim_maps=None,
612
- )
613
- for fig_token in figs_images:
614
- for token, (fig, ax) in fig_token.items():
615
- print(f"Token: {token}")
616
- save_figure(fig, f"similarity_map_{token}.png")
617
- print("Done")
 
13
  import re
14
  import io
15
 
 
16
  import time
17
  import backend.testquery as testquery
18
 
 
23
  normalize_similarity_map_per_query_token,
24
  )
25
  from vidore_benchmark.interpretability.vit_configs import VIT_CONFIG
 
 
26
 
27
  matplotlib.use("Agg")
28
+ # Prepare the colormap once to avoid recomputation
29
+ colormap = cm.get_cmap("viridis")
30
 
31
  COLPALI_GEMMA_MODEL_NAME = "vidore/colpaligemma-3b-pt-448-base"
32
 
 
59
  return vit_config
60
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def gen_similarity_maps(
63
  model: ColPali,
64
  processor: ColPaliProcessor,
 
89
 
90
  """
91
 
 
 
 
 
 
92
  # Process images and store original images and sizes
93
  processed_images = []
94
  original_images = []
 
280
  return q_emb, token_to_idx
281
 
282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  def is_special_token(token: str) -> bool:
284
  # Pattern for tokens that start with '<', numbers, whitespace, or single characters, or the string 'Question'
285
  # Will exclude these tokens from the similarity map generation
 
288
  return True
289
  return False
290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
 
292
  def add_sim_maps_to_result(
293
  result: Dict[str, Any],
 
329
  # for token, sim_mapb64 in sim_map_dict.items():
330
  # single_result["fields"][f"sim_map_{token}"] = sim_mapb64
331
  return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
backend/vespa_app.py CHANGED
@@ -1,23 +1,332 @@
1
  import os
2
- from vespa.application import Vespa
 
 
 
 
3
  from dotenv import load_dotenv
 
 
 
 
 
 
 
 
4
 
 
 
 
 
 
 
 
5
 
6
- def get_vespa_app():
7
- load_dotenv()
8
- vespa_app_url = os.environ.get(
9
- "VESPA_APP_URL"
10
- ) # Ensure this is set to your Vespa app URL
11
- vespa_cloud_secret_token = os.environ.get("VESPA_CLOUD_SECRET_TOKEN")
12
 
13
- if not vespa_app_url or not vespa_cloud_secret_token:
14
- raise ValueError(
15
- "Please set the VESPA_APP_URL and VESPA_CLOUD_SECRET_TOKEN environment variables"
 
16
  )
17
- # Instantiate Vespa connection
18
- vespa_app = Vespa(
19
- url=vespa_app_url, vespa_cloud_secret_token=vespa_cloud_secret_token
20
- )
21
- vespa_app.wait_for_application_up()
22
- print(f"Connected to Vespa at {vespa_app_url}")
23
- return vespa_app
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import time
3
+ from typing import Dict, Any, Tuple
4
+
5
+ import numpy as np
6
+ import torch
7
  from dotenv import load_dotenv
8
+ from vespa.application import Vespa
9
+ from vespa.io import VespaQueryResponse
10
+
11
+
12
class VespaQueryClient:
    """Async query client for the Vespa application configured via environment.

    The client reads ``VESPA_APP_URL`` and ``VESPA_CLOUD_SECRET_TOKEN`` (after
    loading a ``.env`` file), connects to the application, and exposes
    coroutines that run text, nearest-neighbor and full-image queries against
    the ``VESPA_SCHEMA_NAME`` schema.
    """

    # Upper bound on the number of query-token embeddings sent to Vespa.
    MAX_QUERY_TERMS = 64
    # Schema (document type) every YQL statement selects from.
    VESPA_SCHEMA_NAME = "pdf_page"
    # Summary fields requested for search hits.
    SELECT_FIELDS = "id,title,url,blur_image,page_number,snippet,text,summaryfeatures"

    def __init__(self):
        """Load environment variables and connect to the Vespa application.

        Raises:
            ValueError: If ``VESPA_APP_URL`` or ``VESPA_CLOUD_SECRET_TOKEN``
                is unset or empty.
        """
        load_dotenv()
        self.vespa_app_url = os.environ.get("VESPA_APP_URL")
        self.vespa_cloud_secret_token = os.environ.get("VESPA_CLOUD_SECRET_TOKEN")

        if not self.vespa_app_url or not self.vespa_cloud_secret_token:
            raise ValueError(
                "Please set the VESPA_APP_URL and VESPA_CLOUD_SECRET_TOKEN environment variables"
            )

        # Instantiate Vespa connection
        self.app = Vespa(
            url=self.vespa_app_url,
            vespa_cloud_secret_token=self.vespa_cloud_secret_token,
        )
        self.app.wait_for_application_up()
        print(f"Connected to Vespa at {self.vespa_app_url}")

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _ensure_successful(response: "VespaQueryResponse") -> None:
        """Raise ``AssertionError`` carrying the response JSON on failure.

        An explicit ``raise`` is used instead of an ``assert`` statement so
        the check is not stripped when Python runs with ``-O``. The exception
        type is kept as ``AssertionError`` for backward compatibility.
        """
        if not response.is_successful():
            raise AssertionError(response.json)

    @staticmethod
    def _reported_searchtime(response: "VespaQueryResponse"):
        """Return ``timing.searchtime`` from the response JSON, or -1 if absent."""
        return response.json.get("timing", {}).get("searchtime", -1)

    @staticmethod
    def _escape_yql_string(value: str) -> str:
        """Escape backslashes and double quotes for a double-quoted YQL literal."""
        return value.replace("\\", "\\\\").replace('"', '\\"')

    @staticmethod
    def _hits_of(result: dict) -> list:
        """Return ``result["root"]["children"]``, or ``[]`` when it is missing.

        NOTE(review): Vespa appears to omit ``children`` for zero-hit results;
        confirm against the default JSON result format.
        """
        return result.get("root", {}).get("children", [])

    async def _run_text_query(
        self,
        ranking: str,
        query: str,
        q_emb: torch.Tensor,
        hits: int,
        timeout: str,
        extra: dict,
    ) -> dict:
        """Run a ``userQuery()`` search with the given ranking profile.

        ``extra`` holds the caller's additional request parameters; it is
        merged last so it can override any of the default body keys.
        """
        async with self.app.asyncio(connections=1) as session:
            query_embedding = self.format_q_embs(q_emb)

            start = time.perf_counter()
            response: VespaQueryResponse = await session.query(
                body={
                    "yql": (
                        f"select {self.SELECT_FIELDS} from {self.VESPA_SCHEMA_NAME} where userQuery();"
                    ),
                    "ranking": ranking,
                    "query": query,
                    "timeout": timeout,
                    "hits": hits,
                    "input.query(qt)": query_embedding,
                    "presentation.timing": True,
                    **extra,
                },
            )
            self._ensure_successful(response)
            stop = time.perf_counter()
            print(
                f"Query time + data transfer took: {stop - start} s, Vespa reported searchtime was "
                f"{self._reported_searchtime(response)} s"
            )
            return self.format_query_results(query, response)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def format_query_results(
        self, query: str, response: "VespaQueryResponse", hits: int = 5
    ) -> dict:
        """Print a one-line summary of the response and return its JSON.

        Args:
            query (str): The query text (only used in the printed summary).
            response (VespaQueryResponse): The response from Vespa.
            hits (int, optional): Currently unused; kept for interface
                compatibility. Defaults to 5.

        Returns:
            dict: The JSON content of the response, unmodified.
        """
        payload = response.json
        query_time = round(payload.get("timing", {}).get("searchtime", -1), 2)
        count = payload.get("root", {}).get("fields", {}).get("totalCount", 0)
        result_text = f"Query text: '{query}', query time {query_time}s, count={count}, top results:\n"
        print(result_text)
        return payload

    async def query_vespa_default(
        self,
        query: str,
        q_emb: torch.Tensor,
        hits: int = 3,
        timeout: str = "10s",
        **kwargs,
    ) -> dict:
        """Query Vespa using the ``default`` ranking profile.

        Args:
            query (str): The query text.
            q_emb (torch.Tensor): Query embeddings, sent as ``input.query(qt)``.
            hits (int, optional): Number of hits to retrieve. Defaults to 3.
            timeout (str, optional): Query timeout. Defaults to "10s".
            **kwargs: Extra request-body parameters; these override defaults.

        Returns:
            dict: The response JSON.

        Raises:
            AssertionError: If Vespa reports the query as unsuccessful.
        """
        return await self._run_text_query(
            "default", query, q_emb, hits, timeout, kwargs
        )

    async def query_vespa_bm25(
        self,
        query: str,
        q_emb: torch.Tensor,
        hits: int = 3,
        timeout: str = "10s",
        **kwargs,
    ) -> dict:
        """Query Vespa using the ``bm25`` ranking profile.

        Args:
            query (str): The query text.
            q_emb (torch.Tensor): Query embeddings, sent as ``input.query(qt)``.
            hits (int, optional): Number of hits to retrieve. Defaults to 3.
            timeout (str, optional): Query timeout. Defaults to "10s".
            **kwargs: Extra request-body parameters; these override defaults.

        Returns:
            dict: The response JSON.

        Raises:
            AssertionError: If Vespa reports the query as unsuccessful.
        """
        return await self._run_text_query("bm25", query, q_emb, hits, timeout, kwargs)

    def float_to_binary_embedding(self, float_query_embedding: dict) -> dict:
        """Binarize and bit-pack float embeddings.

        Each vector is thresholded at zero (``> 0`` becomes bit 1), packed
        eight bits per byte with ``np.packbits`` and reinterpreted as int8.
        At most ``MAX_QUERY_TERMS`` entries (in input order) are kept; a
        warning is printed only when entries are actually dropped.

        Args:
            float_query_embedding (dict): Mapping of key -> float vector.

        Returns:
            dict: Mapping of the same keys -> list of int8 values.
        """
        entries = list(float_query_embedding.items())
        if len(entries) > self.MAX_QUERY_TERMS:
            print(
                f"Warning: Query has more than {self.MAX_QUERY_TERMS} terms. Truncating."
            )
            entries = entries[: self.MAX_QUERY_TERMS]
        return {
            key: np.packbits(np.where(np.array(vector) > 0, 1, 0))
            .astype(np.int8)
            .tolist()
            for key, vector in entries
        }

    def create_nn_query_strings(
        self, binary_query_embeddings: dict, target_hits_per_query_tensor: int = 20
    ) -> Tuple[str, dict]:
        """Build the nearestNeighbor YQL clause and its query tensors.

        One ``nearestNeighbor(embedding,rq<key>)`` operator is emitted per
        entry and the operators are joined with ``OR``. The matching
        ``input.query(rq<key>)`` tensors are returned alongside.

        Args:
            binary_query_embeddings (dict): Mapping of key -> binary vector
                (keys are ``0..n-1`` as produced by ``format_q_embs``).
            target_hits_per_query_tensor (int, optional): ``targetHits``
                annotation for each operator. Defaults to 20.

        Returns:
            Tuple[str, dict]: The YQL fragment (empty string for no entries)
            and the query-tensor dictionary.
        """
        nn_query_dict = {
            f"input.query(rq{key})": vector
            for key, vector in binary_query_embeddings.items()
        }
        nn = " OR ".join(
            f"({{targetHits:{target_hits_per_query_tensor}}}nearestNeighbor(embedding,rq{key}))"
            for key in binary_query_embeddings
        )
        return nn, nn_query_dict

    def format_q_embs(self, q_embs: torch.Tensor) -> dict:
        """Convert query embeddings to a dictionary of lists.

        Args:
            q_embs (torch.Tensor): Query embeddings; iterated along the
                first dimension.

        Returns:
            dict: Mapping of row index -> ``row.tolist()``.
        """
        return {idx: emb.tolist() for idx, emb in enumerate(q_embs)}

    async def get_result_from_query(
        self,
        query: str,
        q_embs: torch.Tensor,
        ranking: str,
        token_to_idx: dict,
    ) -> Dict[str, Any]:
        """Dispatch a query to the method matching ``ranking`` and log the hits.

        Args:
            query (str): The query text.
            q_embs (torch.Tensor): Query embeddings.
            ranking (str): One of "nn+colpali", "bm25+colpali" or "bm25".
            token_to_idx (dict): Token to index mapping (only printed).

        Returns:
            Dict[str, Any]: The response JSON.

        Raises:
            ValueError: If ``ranking`` is not one of the supported values.
            AssertionError: If Vespa reports the query as unsuccessful.
        """
        print(query)
        print(token_to_idx)

        if ranking == "nn+colpali":
            result = await self.query_vespa_nearest_neighbor(query, q_embs)
        elif ranking == "bm25+colpali":
            result = await self.query_vespa_default(query, q_embs)
        elif ranking == "bm25":
            result = await self.query_vespa_bm25(query, q_embs)
        else:
            raise ValueError(f"Unsupported ranking: {ranking}")

        # A zero-hit result may carry no "children" key; treat it as empty
        # instead of raising KeyError while logging.
        children = self._hits_of(result)

        # Print score, title and id of the results
        for idx, child in enumerate(children, start=1):
            print(
                f"Result {idx}: {child['relevance']}, {child['fields']['title']}, {child['fields']['id']}"
            )
        for single_result in children:
            print(single_result["fields"].keys())
        return result

    async def get_full_image_from_vespa(self, doc_id: str) -> str:
        """Retrieve the ``full_image`` field of the first hit matching ``doc_id``.

        Args:
            doc_id (str): The document ID to match against the ``id`` field.

        Returns:
            str: The ``full_image`` field value of the first hit.

        Raises:
            AssertionError: If Vespa reports the query as unsuccessful.
            KeyError, IndexError: If the response contains no matching hit.
        """
        # Escape the id so quotes or backslashes in it cannot terminate the
        # YQL string literal and alter the query.
        safe_id = self._escape_yql_string(doc_id)
        async with self.app.asyncio(connections=1) as session:
            start = time.perf_counter()
            response: VespaQueryResponse = await session.query(
                body={
                    "yql": f'select full_image from {self.VESPA_SCHEMA_NAME} where id contains "{safe_id}"',
                    "ranking": "unranked",
                    "presentation.timing": True,
                },
            )
            self._ensure_successful(response)
            stop = time.perf_counter()
            print(
                f"Getting image from Vespa took: {stop - start} s, Vespa reported searchtime was "
                f"{self._reported_searchtime(response)} s"
            )
        return response.json["root"]["children"][0]["fields"]["full_image"]

    async def query_vespa_nearest_neighbor(
        self,
        query: str,
        q_emb: torch.Tensor,
        target_hits_per_query_tensor: int = 20,
        hits: int = 3,
        timeout: str = "10s",
        **kwargs,
    ) -> dict:
        """Query Vespa with nearest-neighbor retrieval plus ``userQuery()``.

        Sends the float embeddings as ``input.query(qt)``, their binarized
        form as ``input.query(qtb)``, and one ``input.query(rq<i>)`` tensor
        per nearestNeighbor operator, using the ``retrieval-and-rerank``
        ranking profile.

        Args:
            query (str): The query text.
            q_emb (torch.Tensor): Query embeddings.
            target_hits_per_query_tensor (int, optional): ``targetHits`` per
                nearestNeighbor operator. Defaults to 20.
            hits (int, optional): Number of hits to retrieve. Defaults to 3.
            timeout (str, optional): Query timeout. Defaults to "10s".
            **kwargs: Extra request-body parameters; these override defaults.

        Returns:
            dict: The response JSON.

        Raises:
            AssertionError: If Vespa reports the query as unsuccessful.
        """
        async with self.app.asyncio(connections=1) as session:
            float_query_embedding = self.format_q_embs(q_emb)
            binary_query_embeddings = self.float_to_binary_embedding(
                float_query_embedding
            )

            # Mixed tensors for MaxSim calculations
            query_tensors = {
                "input.query(qtb)": binary_query_embeddings,
                "input.query(qt)": float_query_embedding,
            }
            nn_string, nn_query_dict = self.create_nn_query_strings(
                binary_query_embeddings, target_hits_per_query_tensor
            )
            query_tensors.update(nn_query_dict)

            # With no embeddings the nearestNeighbor clause is empty; fall
            # back to plain userQuery() instead of the invalid " or ..." YQL.
            where_clause = f"{nn_string} or userQuery()" if nn_string else "userQuery()"

            response: VespaQueryResponse = await session.query(
                body={
                    **query_tensors,
                    "presentation.timing": True,
                    "yql": (
                        f"select {self.SELECT_FIELDS} from {self.VESPA_SCHEMA_NAME} where {where_clause}"
                    ),
                    "ranking.profile": "retrieval-and-rerank",
                    "timeout": timeout,
                    "hits": hits,
                    "query": query,
                    **kwargs,
                },
            )
            self._ensure_successful(response)
            return self.format_query_results(query, response)
main.py CHANGED
@@ -13,12 +13,10 @@ from backend.cache import LRUCache
13
  from backend.colpali import (
14
  add_sim_maps_to_result,
15
  get_query_embeddings_and_token_map,
16
- get_result_from_query,
17
  is_special_token,
18
- get_full_image_from_vespa,
19
  )
20
  from backend.modelmanager import ModelManager
21
- from backend.vespa_app import get_vespa_app
22
  from frontend.app import (
23
  ChatResult,
24
  Home,
@@ -65,8 +63,7 @@ app, rt = fast_app(
65
  sselink,
66
  ),
67
  )
68
- vespa_app: Vespa = get_vespa_app()
69
-
70
  result_cache = LRUCache(max_size=20) # Each result can be ~10MB
71
  task_cache = LRUCache(
72
  max_size=1000
@@ -173,14 +170,11 @@ async def get(request, query: str, nn: bool = True):
173
 
174
  start = time.perf_counter()
175
  # Fetch real search results from Vespa
176
- result = await get_result_from_query(
177
- app=vespa_app,
178
- processor=processor,
179
- model=model,
180
  query=query,
181
  q_embs=q_embs,
182
- token_to_idx=token_to_idx,
183
  ranking=ranking_value,
 
184
  )
185
  end = time.perf_counter()
186
  print(
@@ -278,7 +272,7 @@ async def full_image(docid: str, query_id: str, idx: int):
278
  """
279
  Endpoint to get the full quality image for a given result id.
280
  """
281
- image_data = await get_full_image_from_vespa(vespa_app, docid)
282
  # Update the cache with the full image data asynchronously to not block the request
283
  asyncio.create_task(update_full_image_cache(docid, query_id, idx, image_data))
284
  # Decode the base64 image data
 
13
  from backend.colpali import (
14
  add_sim_maps_to_result,
15
  get_query_embeddings_and_token_map,
 
16
  is_special_token,
 
17
  )
18
  from backend.modelmanager import ModelManager
19
+ from backend.vespa_app import VespaQueryClient
20
  from frontend.app import (
21
  ChatResult,
22
  Home,
 
63
  sselink,
64
  ),
65
  )
66
+ vespa_app: Vespa = VespaQueryClient()
 
67
  result_cache = LRUCache(max_size=20) # Each result can be ~10MB
68
  task_cache = LRUCache(
69
  max_size=1000
 
170
 
171
  start = time.perf_counter()
172
  # Fetch real search results from Vespa
173
+ result = await vespa_app.get_result_from_query(
 
 
 
174
  query=query,
175
  q_embs=q_embs,
 
176
  ranking=ranking_value,
177
+ token_to_idx=token_to_idx,
178
  )
179
  end = time.perf_counter()
180
  print(
 
272
  """
273
  Endpoint to get the full quality image for a given result id.
274
  """
275
+ image_data = await vespa_app.get_full_image_from_vespa(docid)
276
  # Update the cache with the full image data asynchronously to not block the request
277
  asyncio.create_task(update_full_image_cache(docid, query_id, idx, image_data))
278
  # Decode the base64 image data