andreer commited on
Commit
9e9d8e8
·
verified ·
1 Parent(s): 06120a1

Upload folder using huggingface_hub

Browse files
.gitignore CHANGED
@@ -8,4 +8,6 @@ template/
8
  *.json
9
  output/
10
  pdfs/
11
- static/saved/
 
 
 
8
  *.json
9
  output/
10
  pdfs/
11
+ static/full_images/
12
+ static/sim_maps/
13
+ embeddings/
backend/colpali.py CHANGED
@@ -7,7 +7,7 @@ from typing import cast, Generator
7
  from pathlib import Path
8
  import base64
9
  from io import BytesIO
10
- from typing import Union, Tuple, List, Dict, Any
11
  import matplotlib
12
  import matplotlib.cm as cm
13
  import re
@@ -49,7 +49,7 @@ def load_model() -> Tuple[ColPali, ColPaliProcessor]:
49
 
50
  # Load the processor
51
  processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
52
- return model, processor
53
 
54
 
55
  def load_vit_config(model):
@@ -63,7 +63,6 @@ def gen_similarity_maps(
63
  model: ColPali,
64
  processor: ColPaliProcessor,
65
  device,
66
- vit_config,
67
  query: str,
68
  query_embs: torch.Tensor,
69
  token_idx_map: dict,
@@ -88,7 +87,7 @@ def gen_similarity_maps(
88
  Tuple[int, str, str]: A tuple containing the image index, the selected token, and the base64-encoded image.
89
 
90
  """
91
-
92
  # Process images and store original images and sizes
93
  processed_images = []
94
  original_images = []
@@ -254,7 +253,7 @@ def gen_similarity_maps(
254
 
255
  # Store the base64-encoded image
256
  result_per_image[token] = blended_img_base64
257
- yield idx, token, blended_img_base64
258
  end3 = time.perf_counter()
259
  print(f"Blending images took: {end3 - start3} s")
260
 
@@ -287,60 +286,3 @@ def is_special_token(token: str) -> bool:
287
  if (len(token) < 3) or pattern.match(token):
288
  return True
289
  return False
290
-
291
-
292
- def add_sim_maps_to_result(
293
- result: Dict[str, Any],
294
- model: ColPali,
295
- processor: ColPaliProcessor,
296
- query: str,
297
- q_embs: Any,
298
- token_to_idx: Dict[str, int],
299
- query_id: str,
300
- result_cache,
301
- ) -> Dict[str, Any]:
302
- print("Adding similarity maps to result - query_id:", query_id)
303
- vit_config = load_vit_config(model)
304
- imgs: List[str] = []
305
- vespa_sim_maps: List[str] = []
306
- for single_result in result["root"]["children"]:
307
- img = single_result["fields"]["blur_image"]
308
- if img:
309
- imgs.append(img)
310
- vespa_sim_map = single_result["fields"].get("summaryfeatures", None)
311
- if vespa_sim_map:
312
- vespa_sim_maps.append(vespa_sim_map)
313
- if not imgs:
314
- return result
315
- if len(imgs) != len(vespa_sim_maps):
316
- print(
317
- "Number of images and similarity maps do not match. Skipping similarity map generation."
318
- )
319
- return result
320
- sim_map_imgs_generator = gen_similarity_maps(
321
- model=model,
322
- processor=processor,
323
- device=model.device if hasattr(model, "device") else "cpu",
324
- vit_config=vit_config,
325
- query=query,
326
- query_embs=q_embs,
327
- token_idx_map=token_to_idx,
328
- images=imgs,
329
- vespa_sim_maps=vespa_sim_maps,
330
- )
331
- for img_idx, token, sim_mapb64 in sim_map_imgs_generator:
332
- print(f"Created sim map for image {img_idx} and token {token}")
333
- if (
334
- len(result["root"]["children"]) > img_idx
335
- and "fields" in result["root"]["children"][img_idx]
336
- ):
337
- result["root"]["children"][img_idx]["fields"][f"sim_map_{token}"] = (
338
- sim_mapb64
339
- )
340
- # Update result_cache with the new sim_map
341
- result_cache.set(query_id, result)
342
- else:
343
- print(
344
- f"Could not add sim map to result for image {img_idx} and token {token}"
345
- )
346
- return result
 
7
  from pathlib import Path
8
  import base64
9
  from io import BytesIO
10
+ from typing import Union, Tuple, List
11
  import matplotlib
12
  import matplotlib.cm as cm
13
  import re
 
49
 
50
  # Load the processor
51
  processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
52
+ return model, processor, device
53
 
54
 
55
  def load_vit_config(model):
 
63
  model: ColPali,
64
  processor: ColPaliProcessor,
65
  device,
 
66
  query: str,
67
  query_embs: torch.Tensor,
68
  token_idx_map: dict,
 
87
  Tuple[int, str, str]: A tuple containing the image index, the selected token, and the base64-encoded image.
88
 
89
  """
90
+ vit_config = load_vit_config(model)
91
  # Process images and store original images and sizes
92
  processed_images = []
93
  original_images = []
 
253
 
254
  # Store the base64-encoded image
255
  result_per_image[token] = blended_img_base64
256
+ yield idx, token, token_idx, blended_img_base64
257
  end3 = time.perf_counter()
258
  print(f"Blending images took: {end3 - start3} s")
259
 
 
286
  if (len(token) < 3) or pattern.match(token):
287
  return True
288
  return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
backend/modelmanager.py CHANGED
@@ -17,7 +17,7 @@ class ModelManager:
17
 
18
  def initialize_model_and_processor(self):
19
  if self.model is None or self.processor is None: # Ensure no reinitialization
20
- self.model, self.processor = load_model()
21
  if self.model is None or self.processor is None:
22
  print("Failed to initialize model or processor at startup")
23
  else:
 
17
 
18
  def initialize_model_and_processor(self):
19
  if self.model is None or self.processor is None: # Ensure no reinitialization
20
+ self.model, self.processor, self.device = load_model()
21
  if self.model is None or self.processor is None:
22
  print("Failed to initialize model or processor at startup")
23
  else:
backend/vespa_app.py CHANGED
@@ -1,18 +1,19 @@
1
  import os
2
  import time
3
  from typing import Any, Dict, Tuple
4
-
5
  import numpy as np
6
  import torch
7
  from dotenv import load_dotenv
8
  from vespa.application import Vespa
9
  from vespa.io import VespaQueryResponse
 
10
 
11
 
12
  class VespaQueryClient:
13
  MAX_QUERY_TERMS = 64
14
  VESPA_SCHEMA_NAME = "pdf_page"
15
- SELECT_FIELDS = "id,title,url,blur_image,page_number,snippet,text,summaryfeatures"
16
 
17
  def __init__(self):
18
  """
@@ -73,6 +74,12 @@ class VespaQueryClient:
73
  self.app.wait_for_application_up()
74
  print(f"Connected to Vespa at {self.vespa_app_url}")
75
 
 
 
 
 
 
 
76
  def format_query_results(
77
  self, query: str, response: VespaQueryResponse, hits: int = 5
78
  ) -> dict:
@@ -100,6 +107,7 @@ class VespaQueryClient:
100
  q_emb: torch.Tensor,
101
  hits: int = 3,
102
  timeout: str = "10s",
 
103
  **kwargs,
104
  ) -> dict:
105
  """
@@ -121,9 +129,9 @@ class VespaQueryClient:
121
  response: VespaQueryResponse = await session.query(
122
  body={
123
  "yql": (
124
- f"select {self.SELECT_FIELDS} from {self.VESPA_SCHEMA_NAME} where userQuery();"
125
  ),
126
- "ranking": "default",
127
  "query": query,
128
  "timeout": timeout,
129
  "hits": hits,
@@ -146,6 +154,7 @@ class VespaQueryClient:
146
  q_emb: torch.Tensor,
147
  hits: int = 3,
148
  timeout: str = "10s",
 
149
  **kwargs,
150
  ) -> dict:
151
  """
@@ -167,9 +176,9 @@ class VespaQueryClient:
167
  response: VespaQueryResponse = await session.query(
168
  body={
169
  "yql": (
170
- f"select {self.SELECT_FIELDS} from {self.VESPA_SCHEMA_NAME} where userQuery();"
171
  ),
172
- "ranking": "bm25",
173
  "query": query,
174
  "timeout": timeout,
175
  "hits": hits,
@@ -266,30 +275,54 @@ class VespaQueryClient:
266
  Returns:
267
  Dict[str, Any]: The query results.
268
  """
269
- print(query)
270
- print(token_to_idx)
271
-
272
- if ranking == "nn+colpali":
273
- result = await self.query_vespa_nearest_neighbor(query, q_embs)
274
- elif ranking == "bm25+colpali":
275
- result = await self.query_vespa_default(query, q_embs)
276
- elif ranking == "bm25":
277
- result = await self.query_vespa_bm25(query, q_embs)
 
278
  else:
279
- raise ValueError(f"Unsupported ranking: {ranking}")
280
-
281
  # Print score, title id, and text of the results
282
  if "root" not in result or "children" not in result["root"]:
283
  result["root"] = {"children": []}
284
  return result
285
- for idx, child in enumerate(result["root"]["children"]):
286
- print(
287
- f"Result {idx+1}: {child['relevance']}, {child['fields']['title']}, {child['fields']['id']}"
288
- )
289
  for single_result in result["root"]["children"]:
290
  print(single_result["fields"].keys())
291
  return result
292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
  async def get_full_image_from_vespa(self, doc_id: str) -> str:
294
  """
295
  Retrieve the full image from Vespa for a given document ID.
@@ -317,6 +350,23 @@ class VespaQueryClient:
317
  )
318
  return response.json["root"]["children"][0]["fields"]["full_image"]
319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
  async def get_suggestions(self, query: str) -> list:
321
  async with self.app.asyncio(connections=1) as session:
322
  start = time.perf_counter()
@@ -348,6 +398,12 @@ class VespaQueryClient:
348
  flat_questions = [item for sublist in questions for item in sublist]
349
  return flat_questions
350
 
 
 
 
 
 
 
351
  async def query_vespa_nearest_neighbor(
352
  self,
353
  query: str,
@@ -355,6 +411,7 @@ class VespaQueryClient:
355
  target_hits_per_query_tensor: int = 20,
356
  hits: int = 3,
357
  timeout: str = "10s",
 
358
  **kwargs,
359
  ) -> dict:
360
  """
@@ -385,15 +442,16 @@ class VespaQueryClient:
385
  binary_query_embeddings, target_hits_per_query_tensor
386
  )
387
  query_tensors.update(nn_query_dict)
388
-
389
  response: VespaQueryResponse = await session.query(
390
  body={
391
  **query_tensors,
392
  "presentation.timing": True,
393
  "yql": (
394
- f"select {self.SELECT_FIELDS} from {self.VESPA_SCHEMA_NAME} where {nn_string} or userQuery()"
 
 
 
395
  ),
396
- "ranking.profile": "retrieval-and-rerank",
397
  "timeout": timeout,
398
  "hits": hits,
399
  "query": query,
 
1
  import os
2
  import time
3
  from typing import Any, Dict, Tuple
4
+ import asyncio
5
  import numpy as np
6
  import torch
7
  from dotenv import load_dotenv
8
  from vespa.application import Vespa
9
  from vespa.io import VespaQueryResponse
10
+ from .colpali import is_special_token
11
 
12
 
13
  class VespaQueryClient:
14
  MAX_QUERY_TERMS = 64
15
  VESPA_SCHEMA_NAME = "pdf_page"
16
+ SELECT_FIELDS = "id,title,url,blur_image,page_number,snippet,text"
17
 
18
  def __init__(self):
19
  """
 
74
  self.app.wait_for_application_up()
75
  print(f"Connected to Vespa at {self.vespa_app_url}")
76
 
77
+ def get_fields(self, sim_map: bool = False):
78
+ if not sim_map:
79
+ return self.SELECT_FIELDS
80
+ else:
81
+ return "summaryfeatures"
82
+
83
  def format_query_results(
84
  self, query: str, response: VespaQueryResponse, hits: int = 5
85
  ) -> dict:
 
107
  q_emb: torch.Tensor,
108
  hits: int = 3,
109
  timeout: str = "10s",
110
+ sim_map: bool = False,
111
  **kwargs,
112
  ) -> dict:
113
  """
 
129
  response: VespaQueryResponse = await session.query(
130
  body={
131
  "yql": (
132
+ f"select {self.get_fields(sim_map=sim_map)} from {self.VESPA_SCHEMA_NAME} where userQuery();"
133
  ),
134
+ "ranking": self.get_rank_profile("default", sim_map),
135
  "query": query,
136
  "timeout": timeout,
137
  "hits": hits,
 
154
  q_emb: torch.Tensor,
155
  hits: int = 3,
156
  timeout: str = "10s",
157
+ sim_map: bool = False,
158
  **kwargs,
159
  ) -> dict:
160
  """
 
176
  response: VespaQueryResponse = await session.query(
177
  body={
178
  "yql": (
179
+ f"select {self.get_fields(sim_map=sim_map)} from {self.VESPA_SCHEMA_NAME} where userQuery();"
180
  ),
181
+ "ranking": self.get_rank_profile("bm25", sim_map),
182
  "query": query,
183
  "timeout": timeout,
184
  "hits": hits,
 
275
  Returns:
276
  Dict[str, Any]: The query results.
277
  """
278
+ rank_method = ranking.split("_")[0]
279
+ sim_map: bool = len(ranking.split("_")) > 1 and ranking.split("_")[1] == "sim"
280
+ if rank_method == "nn+colpali":
281
+ result = await self.query_vespa_nearest_neighbor(
282
+ query, q_embs, sim_map=sim_map
283
+ )
284
+ elif rank_method == "bm25+colpali":
285
+ result = await self.query_vespa_default(query, q_embs, sim_map=sim_map)
286
+ elif rank_method == "bm25":
287
+ result = await self.query_vespa_bm25(query, q_embs, sim_map=sim_map)
288
  else:
289
+ raise ValueError(f"Unsupported ranking: {rank_method}")
 
290
  # Print score, title id, and text of the results
291
  if "root" not in result or "children" not in result["root"]:
292
  result["root"] = {"children": []}
293
  return result
 
 
 
 
294
  for single_result in result["root"]["children"]:
295
  print(single_result["fields"].keys())
296
  return result
297
 
298
+ def get_sim_maps_from_query(
299
+ self, query: str, q_embs: torch.Tensor, ranking: str, token_to_idx: dict
300
+ ):
301
+ """
302
+ Get similarity maps from Vespa based on the ranking method.
303
+
304
+ Args:
305
+ query (str): The query text.
306
+ q_embs (torch.Tensor): Query embeddings.
307
+ ranking (str): The ranking method to use.
308
+ token_to_idx (dict): Token to index mapping.
309
+
310
+ Returns:
311
+ Dict[str, Any]: The query results.
312
+ """
313
+ # Get the result by calling asyncio.run
314
+ result = asyncio.run(
315
+ self.get_result_from_query(query, q_embs, ranking, token_to_idx)
316
+ )
317
+ vespa_sim_maps = []
318
+ for single_result in result["root"]["children"]:
319
+ vespa_sim_map = single_result["fields"].get("summaryfeatures", None)
320
+ if vespa_sim_map is not None:
321
+ vespa_sim_maps.append(vespa_sim_map)
322
+ else:
323
+ raise ValueError("No sim_map found in Vespa response")
324
+ return vespa_sim_maps
325
+
326
  async def get_full_image_from_vespa(self, doc_id: str) -> str:
327
  """
328
  Retrieve the full image from Vespa for a given document ID.
 
350
  )
351
  return response.json["root"]["children"][0]["fields"]["full_image"]
352
 
353
+ def get_results_children(self, result: VespaQueryResponse) -> list:
354
+ return result["root"]["children"]
355
+
356
+ def results_to_search_results(
357
+ self, result: VespaQueryResponse, token_to_idx: dict
358
+ ) -> list:
359
+ # Initialize sim_map_ fields in the result
360
+ fields_to_add = [
361
+ f"sim_map_{token}_{idx}"
362
+ for idx, token in enumerate(token_to_idx.keys())
363
+ if not is_special_token(token)
364
+ ]
365
+ for child in result["root"]["children"]:
366
+ for sim_map_key in fields_to_add:
367
+ child["fields"][sim_map_key] = None
368
+ return self.get_results_children(result)
369
+
370
  async def get_suggestions(self, query: str) -> list:
371
  async with self.app.asyncio(connections=1) as session:
372
  start = time.perf_counter()
 
398
  flat_questions = [item for sublist in questions for item in sublist]
399
  return flat_questions
400
 
401
+ def get_rank_profile(self, ranking: str, sim_map: bool) -> str:
402
+ if sim_map:
403
+ return f"{ranking}_sim"
404
+ else:
405
+ return ranking
406
+
407
  async def query_vespa_nearest_neighbor(
408
  self,
409
  query: str,
 
411
  target_hits_per_query_tensor: int = 20,
412
  hits: int = 3,
413
  timeout: str = "10s",
414
+ sim_map: bool = False,
415
  **kwargs,
416
  ) -> dict:
417
  """
 
442
  binary_query_embeddings, target_hits_per_query_tensor
443
  )
444
  query_tensors.update(nn_query_dict)
 
445
  response: VespaQueryResponse = await session.query(
446
  body={
447
  **query_tensors,
448
  "presentation.timing": True,
449
  "yql": (
450
+ f"select {self.get_fields(sim_map=sim_map)} from {self.VESPA_SCHEMA_NAME} where {nn_string} or userQuery()"
451
+ ),
452
+ "ranking.profile": self.get_rank_profile(
453
+ "retrieval-and-rerank", sim_map
454
  ),
 
455
  "timeout": timeout,
456
  "hits": hits,
457
  "query": query,
colpali-with-snippets/security/clients.pem CHANGED
@@ -1,9 +1,9 @@
1
  -----BEGIN CERTIFICATE-----
2
- MIIBNzCB36ADAgECAhEA7NcEr8GPdHvOjOU25NL76DAKBggqhkjOPQQDAjAeMRww
3
- GgYDVQQDExNjbG91ZC52ZXNwYS5leGFtcGxlMB4XDTI0MTAxMTA3MTk1OFoXDTM0
4
- MTAwOTA3MTk1OFowHjEcMBoGA1UEAxMTY2xvdWQudmVzcGEuZXhhbXBsZTBZMBMG
5
- ByqGSM49AgEGCCqGSM49AwEHA0IABF2UNuCudYzj+huI3Fl/uNDRYtbYYo2ex8Od
6
- Pij13s9gXQ7gEhHhDJuFPRfk0zBLr4sxQiDom8OHMvSiNbuhg2cwCgYIKoZIzj0E
7
- AwIDRwAwRAIgFcoy6ECnq07Hd9mYFEEETQRCr645aY47jwLbwUsEU9oCIARkTAyI
8
- cYe6Z1vAmA3IGPIn/gAPaSVVJiSe4QhzmqW/
9
  -----END CERTIFICATE-----
 
1
  -----BEGIN CERTIFICATE-----
2
+ MIIBODCB36ADAgECAhEAr37yU2TKTDMdQW1txTaMSjAKBggqhkjOPQQDAjAeMRww
3
+ GgYDVQQDExNjbG91ZC52ZXNwYS5leGFtcGxlMB4XDTI0MTAxNzA5NDY1M1oXDTM0
4
+ MTAxNTA5NDY1M1owHjEcMBoGA1UEAxMTY2xvdWQudmVzcGEuZXhhbXBsZTBZMBMG
5
+ ByqGSM49AgEGCCqGSM49AwEHA0IABPQjpb7RFvtnw288EY5eolq2v+0qC0h4JeW5
6
+ jCchXp4KUa5ufqeqyTcAxsfLn3BloPFDJ7Vb2gct9tZONa7xvc4wCgYIKoZIzj0E
7
+ AwIDSAAwRQIgR3wU3NUS02Behd0ojxo5sa5NVi0HhNW8RoAy0UyoGnACIQDWOqq+
8
+ zdKHJDorFuMWeMKKUe0cVQXZV3RvU5ssuXyEnw==
9
  -----END CERTIFICATE-----
frontend/app.py CHANGED
@@ -323,14 +323,13 @@ def SimMapButtonReady(query_id, idx, token, img_src):
323
  )
324
 
325
 
326
- def SimMapButtonPoll(query_id, idx, token):
327
  return Button(
328
  Lucide(icon="loader-circle", size="15", cls="animate-spin"),
329
  size="sm",
330
  disabled=True,
331
- hx_get=f"/get_sim_map?query_id={query_id}&idx={idx}&token={token}",
332
- # Poll every x seconds, where x is 0.3 x idx, formatted to 2 decimals
333
- hx_trigger=f"every {(idx+1)*0.3:.2f}s",
334
  hx_swap="outerHTML",
335
  cls="pointer-events-auto text-xs h-5 rounded-none px-2",
336
  )
@@ -352,7 +351,6 @@ def SearchResult(results: list, query_id: Optional[str] = None):
352
  fields = result["fields"] # Extract the 'fields' part of each result
353
  blur_image_base64 = f"data:image/jpeg;base64,{fields['blur_image']}"
354
 
355
- # Filter sim_map fields that are words with 4 or more characters
356
  sim_map_fields = {
357
  key: value
358
  for key, value in fields.items()
@@ -370,14 +368,17 @@ def SearchResult(results: list, query_id: Optional[str] = None):
370
  SimMapButtonReady(
371
  query_id=query_id,
372
  idx=idx,
373
- token=key.split("_")[-1],
374
  img_src=sim_map_base64,
375
  )
376
  )
377
  else:
378
  sim_map_buttons.append(
379
  SimMapButtonPoll(
380
- query_id=query_id, idx=idx, token=key.split("_")[-1]
 
 
 
381
  )
382
  )
383
 
 
323
  )
324
 
325
 
326
+ def SimMapButtonPoll(query_id, idx, token, token_idx):
327
  return Button(
328
  Lucide(icon="loader-circle", size="15", cls="animate-spin"),
329
  size="sm",
330
  disabled=True,
331
+ hx_get=f"/get_sim_map?query_id={query_id}&idx={idx}&token={token}&token_idx={token_idx}",
332
+ hx_trigger="every 0.5s",
 
333
  hx_swap="outerHTML",
334
  cls="pointer-events-auto text-xs h-5 rounded-none px-2",
335
  )
 
351
  fields = result["fields"] # Extract the 'fields' part of each result
352
  blur_image_base64 = f"data:image/jpeg;base64,{fields['blur_image']}"
353
 
 
354
  sim_map_fields = {
355
  key: value
356
  for key, value in fields.items()
 
368
  SimMapButtonReady(
369
  query_id=query_id,
370
  idx=idx,
371
+ token=key.split("_")[-2],
372
  img_src=sim_map_base64,
373
  )
374
  )
375
  else:
376
  sim_map_buttons.append(
377
  SimMapButtonPoll(
378
+ query_id=query_id,
379
+ idx=idx,
380
+ token=key.split("_")[-2],
381
+ token_idx=int(key.split("_")[-1]),
382
  )
383
  )
384
 
main.py CHANGED
@@ -1,26 +1,33 @@
1
  import asyncio
2
- import base64
3
- import io
4
  import os
5
  import time
6
- from concurrent.futures import ThreadPoolExecutor
7
- from functools import partial
8
  from pathlib import Path
 
9
  import uuid
10
- import hashlib
11
-
12
  import google.generativeai as genai
13
- from fasthtml.common import *
14
- from PIL import Image
15
- from shad4fast import *
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  from vespa.application import Vespa
 
 
 
17
 
18
- from backend.cache import LRUCache
19
- from backend.colpali import (
20
- add_sim_maps_to_result,
21
- get_query_embeddings_and_token_map,
22
- is_special_token,
23
- )
24
  from backend.modelmanager import ModelManager
25
  from backend.vespa_app import VespaQueryClient
26
  from frontend.app import (
@@ -77,10 +84,6 @@ app, rt = fast_app(
77
  ),
78
  )
79
  vespa_app: Vespa = VespaQueryClient()
80
- result_cache = LRUCache(max_size=20) # Each result can be ~10MB
81
- task_cache = LRUCache(
82
- max_size=1000
83
- ) # Map from query_id to boolean value - False if not all results are ready.
84
  thread_pool = ThreadPoolExecutor()
85
  # Gemini config
86
 
@@ -95,9 +98,11 @@ But, you should NOT include backticks (`) or HTML tags in your response.
95
  gemini_model = genai.GenerativeModel(
96
  "gemini-1.5-flash-8b", system_instruction=GEMINI_SYSTEM_PROMPT
97
  )
98
- STATIC_DIR = Path(__file__).parent / "static"
99
- IMG_DIR = STATIC_DIR / "saved"
100
- os.makedirs(STATIC_DIR, exist_ok=True)
 
 
101
 
102
 
103
  @app.on_event("startup")
@@ -112,9 +117,9 @@ async def keepalive():
112
  return
113
 
114
 
115
- def generate_query_id(session_id, query, ranking_value):
116
- hash_input = (session_id + query + ranking_value).encode("utf-8")
117
- return hashlib.sha256(hash_input).hexdigest()
118
 
119
 
120
  @rt("/static/{filepath:path}")
@@ -135,7 +140,7 @@ def get():
135
 
136
 
137
  @rt("/search")
138
- def get(session, request):
139
  # Extract the 'query' and 'ranking' parameters from the URL
140
  query_value = request.query_params.get("query", "").strip()
141
  ranking_value = request.query_params.get("ranking", "nn+colpali")
@@ -160,12 +165,7 @@ def get(session, request):
160
  )
161
  )
162
  # Generate a unique query_id based on the query and ranking value
163
- if "query_id" not in session:
164
- session["query_id"] = generate_query_id(
165
- session["session_id"], query_value, ranking_value
166
- )
167
- query_id = session.get("query_id")
168
- print(f"Query id in /search: {query_id}")
169
  # Show the loading message if a query is provided
170
  return Layout(
171
  Main(Search(request), data_overlayscrollbars_initialize=True, cls="border-t"),
@@ -176,23 +176,31 @@ def get(session, request):
176
  ) # Show SearchBox and Loading message initially
177
 
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  @rt("/fetch_results")
180
- async def get(session, request, query: str, nn: bool = True):
181
  if "hx-request" not in request.headers:
182
  return RedirectResponse("/search")
183
 
184
- # Extract ranking option from the request
185
- ranking_value = request.query_params.get("ranking")
186
- print(
187
- f"/fetch_results: Fetching results for query: {query}, ranking: {ranking_value}"
188
- )
189
- # Generate a unique query_id based on the query and ranking value
190
- print(f"Sesssion in /fetch_results: {session}")
191
- if "query_id" not in session:
192
- session["query_id"] = generate_query_id(
193
- session["session_id"], query_value, ranking_value
194
- )
195
- query_id = session.get("query_id")
196
  print(f"Query id in /fetch_results: {query_id}")
197
  # Run the embedding and query against Vespa app
198
  model = app.manager.model
@@ -204,30 +212,21 @@ async def get(session, request, query: str, nn: bool = True):
204
  result = await vespa_app.get_result_from_query(
205
  query=query,
206
  q_embs=q_embs,
207
- ranking=ranking_value,
208
  token_to_idx=token_to_idx,
209
  )
210
  end = time.perf_counter()
211
  print(
212
  f"Search results fetched in {end - start:.2f} seconds, Vespa says searchtime was {result['timing']['searchtime']} seconds"
213
  )
214
- # Initialize sim_map_ fields in the result
215
- fields_to_add = [
216
- f"sim_map_{token}"
217
- for token in token_to_idx.keys()
218
- if not is_special_token(token)
219
- ]
220
- for child in result["root"]["children"]:
221
- for sim_map_key in fields_to_add:
222
- child["fields"][sim_map_key] = None
223
- result_cache.set(query_id, result)
224
- # Start generating the similarity map in the background
225
- asyncio.create_task(
226
- generate_similarity_map(
227
- model, processor, query, q_embs, token_to_idx, result, query_id
228
- )
229
  )
230
- search_results = get_results_children(result)
231
  return SearchResult(search_results, query_id)
232
 
233
 
@@ -247,78 +246,84 @@ async def poll_vespa_keepalive():
247
  print(f"Vespa keepalive: {time.time()}")
248
 
249
 
250
- async def generate_similarity_map(
251
- model, processor, query, q_embs, token_to_idx, result, query_id
252
- ):
253
- loop = asyncio.get_event_loop()
254
- sim_map_task = partial(
255
- add_sim_maps_to_result,
256
- result=result,
257
- model=model,
258
- processor=processor,
259
  query=query,
260
  q_embs=q_embs,
 
261
  token_to_idx=token_to_idx,
262
- query_id=query_id,
263
- result_cache=result_cache,
264
  )
265
- sim_map_result = await loop.run_in_executor(thread_pool, sim_map_task)
266
- result_cache.set(query_id, sim_map_result)
267
- task_cache.set(query_id, True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
 
269
 
270
  @app.get("/get_sim_map")
271
- async def get_sim_map(query_id: str, idx: int, token: str):
272
  """
273
  Endpoint that each of the sim map button polls to get the sim map image
274
  when it is ready. If it is not ready, returns a SimMapButtonPoll, that
275
  continues to poll every 1 second.
276
  """
277
- result = result_cache.get(query_id)
278
- if result is None:
279
- return SimMapButtonPoll(query_id=query_id, idx=idx, token=token)
280
- search_results = get_results_children(result)
281
- # Check if idx exists in list of children
282
- if idx >= len(search_results):
283
- return SimMapButtonPoll(query_id=query_id, idx=idx, token=token)
284
  else:
285
- sim_map_key = f"sim_map_{token}"
286
- sim_map_b64 = search_results[idx]["fields"].get(sim_map_key, None)
287
- if sim_map_b64 is None:
288
- return SimMapButtonPoll(query_id=query_id, idx=idx, token=token)
289
- sim_map_img_src = f"data:image/png;base64,{sim_map_b64}"
290
  return SimMapButtonReady(
291
- query_id=query_id, idx=idx, token=token, img_src=sim_map_img_src
292
  )
293
 
294
 
295
- async def update_full_image_cache(docid: str, query_id: str, idx: int, image_data: str):
296
- result = None
297
- max_wait = 20 # seconds. If horribly slow network latency.
298
- start_time = time.time()
299
- while result is None and time.time() - start_time < max_wait:
300
- result = result_cache.get(query_id)
301
- if result is None:
302
- await asyncio.sleep(0.1)
303
- try:
304
- result["root"]["children"][idx]["fields"]["full_image"] = image_data
305
- except KeyError as err:
306
- print(f"Error updating full image cache: {err}")
307
- result_cache.set(query_id, result)
308
- print(f"Full image cache updated for query_id {query_id}")
309
- return
310
-
311
-
312
  @app.get("/full_image")
313
  async def full_image(docid: str, query_id: str, idx: int):
314
  """
315
  Endpoint to get the full quality image for a given result id.
316
  """
317
- image_data = await vespa_app.get_full_image_from_vespa(docid)
318
- # Update the cache with the full image data
319
- asyncio.create_task(update_full_image_cache(docid, query_id, idx, image_data))
 
 
 
 
 
 
 
320
  return Img(
321
- src=f"data:image/png;base64,{image_data}",
322
  alt="something",
323
  cls="result-image w-full h-full object-contain",
324
  )
@@ -338,28 +343,25 @@ async def get_suggestions(request):
338
 
339
  async def message_generator(query_id: str, query: str):
340
  images = []
341
- result = None
342
- all_images_ready = False
343
  max_wait = 10 # seconds
344
  start_time = time.time()
345
- while not all_images_ready and time.time() - start_time < max_wait:
346
- result = result_cache.get(query_id)
347
- if result is None:
348
- await asyncio.sleep(0.1)
349
- continue
350
- search_results = get_results_children(result)
351
- for single_result in search_results:
352
- img = single_result["fields"].get("full_image", None)
353
- if img is not None:
354
- images.append(img)
355
- if len(images) == len(search_results):
356
- all_images_ready = True
357
- break
358
  else:
359
- await asyncio.sleep(0.1)
360
-
361
- # from b64 to PIL image
362
- images = [Image.open(io.BytesIO(base64.b64decode(img))) for img in images]
 
 
 
363
  if not images:
364
  yield "event: message\ndata: I am sorry, I do not have enough information in the image to answer your question.\n\n"
365
  yield "event: close\ndata: \n\n"
 
1
  import asyncio
 
 
2
  import os
3
  import time
 
 
4
  from pathlib import Path
5
+ from concurrent.futures import ThreadPoolExecutor
6
  import uuid
 
 
7
  import google.generativeai as genai
8
+ from fasthtml.common import (
9
+ Div,
10
+ Img,
11
+ Main,
12
+ P,
13
+ Script,
14
+ Link,
15
+ fast_app,
16
+ HighlightJS,
17
+ FileResponse,
18
+ RedirectResponse,
19
+ Aside,
20
+ StreamingResponse,
21
+ JSONResponse,
22
+ serve,
23
+ )
24
+ from shad4fast import ShadHead
25
  from vespa.application import Vespa
26
+ import base64
27
+ from fastcore.parallel import threaded
28
+ from PIL import Image
29
 
30
+ from backend.colpali import get_query_embeddings_and_token_map, gen_similarity_maps
 
 
 
 
 
31
  from backend.modelmanager import ModelManager
32
  from backend.vespa_app import VespaQueryClient
33
  from frontend.app import (
 
84
  ),
85
  )
86
  vespa_app: Vespa = VespaQueryClient()
 
 
 
 
87
  thread_pool = ThreadPoolExecutor()
88
  # Gemini config
89
 
 
98
  gemini_model = genai.GenerativeModel(
99
  "gemini-1.5-flash-8b", system_instruction=GEMINI_SYSTEM_PROMPT
100
  )
101
+ STATIC_DIR = Path("static")
102
+ IMG_DIR = STATIC_DIR / "full_images"
103
+ SIM_MAP_DIR = STATIC_DIR / "sim_maps"
104
+ os.makedirs(IMG_DIR, exist_ok=True)
105
+ os.makedirs(SIM_MAP_DIR, exist_ok=True)
106
 
107
 
108
  @app.on_event("startup")
 
117
  return
118
 
119
 
120
+ def generate_query_id(query, ranking_value):
121
+ hash_input = (query + ranking_value).encode("utf-8")
122
+ return hash(hash_input)
123
 
124
 
125
  @rt("/static/{filepath:path}")
 
140
 
141
 
142
  @rt("/search")
143
+ def get(request):
144
  # Extract the 'query' and 'ranking' parameters from the URL
145
  query_value = request.query_params.get("query", "").strip()
146
  ranking_value = request.query_params.get("ranking", "nn+colpali")
 
165
  )
166
  )
167
  # Generate a unique query_id based on the query and ranking value
168
+ query_id = generate_query_id(query_value, ranking_value)
 
 
 
 
 
169
  # Show the loading message if a query is provided
170
  return Layout(
171
  Main(Search(request), data_overlayscrollbars_initialize=True, cls="border-t"),
 
176
  ) # Show SearchBox and Loading message initially
177
 
178
 
179
+ @rt("/fetch_results2")
180
+ def get(query: str, ranking: str):
181
+ # 1. Get the results from Vespa (without sim_maps and full_images)
182
+ # Call search-endpoint in Vespa sync.
183
+
184
+ # 2. Kick off tasks to fetch sim_maps and full_images
185
+ # Sim maps - call search endpoint async.
186
+ # (A) New rank_profile that does not calculate sim_maps.
187
+ # (A) Make vespa endpoints take select_fields as a parameter.
188
+ # One sim map per image per token.
189
+ # the filename query_id_result_idx_token_idx.png
190
+ # Full image. based on the doc_id.
191
+ # Each of these tasks saves to disk.
192
+ # Need a cleanup task to delete old files.
193
+ # Polling endpoints for sim_maps and full_images checks if file exists and returns it.
194
+ pass
195
+
196
+
197
  @rt("/fetch_results")
198
+ async def get(session, request, query: str, ranking: str):
199
  if "hx-request" not in request.headers:
200
  return RedirectResponse("/search")
201
 
202
+ # Get the hash of the query and ranking value
203
+ query_id = generate_query_id(query, ranking)
 
 
 
 
 
 
 
 
 
 
204
  print(f"Query id in /fetch_results: {query_id}")
205
  # Run the embedding and query against Vespa app
206
  model = app.manager.model
 
212
  result = await vespa_app.get_result_from_query(
213
  query=query,
214
  q_embs=q_embs,
215
+ ranking=ranking,
216
  token_to_idx=token_to_idx,
217
  )
218
  end = time.perf_counter()
219
  print(
220
  f"Search results fetched in {end - start:.2f} seconds, Vespa says searchtime was {result['timing']['searchtime']} seconds"
221
  )
222
+ search_results = vespa_app.results_to_search_results(result, token_to_idx)
223
+ get_and_store_sim_maps(
224
+ query_id=query_id,
225
+ query=query,
226
+ q_embs=q_embs,
227
+ ranking=ranking,
228
+ token_to_idx=token_to_idx,
 
 
 
 
 
 
 
 
229
  )
 
230
  return SearchResult(search_results, query_id)
231
 
232
 
 
246
  print(f"Vespa keepalive: {time.time()}")
247
 
248
 
249
+ @threaded
250
+ def get_and_store_sim_maps(query_id, query: str, q_embs, ranking, token_to_idx):
251
+ ranking_sim = ranking + "_sim"
252
+ vespa_sim_maps = vespa_app.get_sim_maps_from_query(
 
 
 
 
 
253
  query=query,
254
  q_embs=q_embs,
255
+ ranking=ranking_sim,
256
  token_to_idx=token_to_idx,
 
 
257
  )
258
+ img_paths = [
259
+ IMG_DIR / f"{query_id}_{idx}.jpg" for idx in range(len(vespa_sim_maps))
260
+ ]
261
+ # All images should be downloaded, but best to wait 5 secs
262
+ max_wait = 5
263
+ start_time = time.time()
264
+ while (
265
+ not all([os.path.exists(img_path) for img_path in img_paths])
266
+ and time.time() - start_time < max_wait
267
+ ):
268
+ time.sleep(0.2)
269
+ if not all([os.path.exists(img_path) for img_path in img_paths]):
270
+ print(f"Images not ready in 5 seconds for query_id: {query_id}")
271
+ return False
272
+ sim_map_generator = gen_similarity_maps(
273
+ model=app.manager.model,
274
+ processor=app.manager.processor,
275
+ device=app.manager.device,
276
+ query=query,
277
+ query_embs=q_embs,
278
+ token_idx_map=token_to_idx,
279
+ images=img_paths,
280
+ vespa_sim_maps=vespa_sim_maps,
281
+ )
282
+ for idx, token, token_idx, blended_img_base64 in sim_map_generator:
283
+ with open(SIM_MAP_DIR / f"{query_id}_{idx}_{token_idx}.png", "wb") as f:
284
+ f.write(base64.b64decode(blended_img_base64))
285
+ print(
286
+ f"Sim map saved to disk for query_id: {query_id}, idx: {idx}, token: {token}"
287
+ )
288
+ return True
289
 
290
 
291
  @app.get("/get_sim_map")
292
+ async def get_sim_map(query_id: str, idx: int, token: str, token_idx: int):
293
  """
294
  Endpoint that each of the sim map button polls to get the sim map image
295
  when it is ready. If it is not ready, returns a SimMapButtonPoll, that
296
  continues to poll every 1 second.
297
  """
298
+ sim_map_path = SIM_MAP_DIR / f"{query_id}_{idx}_{token_idx}.png"
299
+ if not os.path.exists(sim_map_path):
300
+ print(f"Sim map not ready for query_id: {query_id}, idx: {idx}, token: {token}")
301
+ return SimMapButtonPoll(
302
+ query_id=query_id, idx=idx, token=token, token_idx=token_idx
303
+ )
 
304
  else:
 
 
 
 
 
305
  return SimMapButtonReady(
306
+ query_id=query_id, idx=idx, token=token, img_src=sim_map_path
307
  )
308
 
309
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  @app.get("/full_image")
311
  async def full_image(docid: str, query_id: str, idx: int):
312
  """
313
  Endpoint to get the full quality image for a given result id.
314
  """
315
+ img_path = IMG_DIR / f"{query_id}_{idx}.jpg"
316
+ if not os.path.exists(img_path):
317
+ image_data = await vespa_app.get_full_image_from_vespa(docid)
318
+ # image data is base 64 encoded string. Save it to disk as jpg.
319
+ with open(img_path, "wb") as f:
320
+ f.write(base64.b64decode(image_data))
321
+ print(f"Full image saved to disk for query_id: {query_id}, idx: {idx}")
322
+ else:
323
+ with open(img_path, "rb") as f:
324
+ image_data = base64.b64encode(f.read()).decode("utf-8")
325
  return Img(
326
+ src=f"data:image/jpeg;base64,{image_data}",
327
  alt="something",
328
  cls="result-image w-full h-full object-contain",
329
  )
 
343
 
344
  async def message_generator(query_id: str, query: str):
345
  images = []
346
+ num_images = 3 # Number of images before firing chat request
 
347
  max_wait = 10 # seconds
348
  start_time = time.time()
349
+ # Check if full images are ready on disk
350
+ while len(images) < num_images and time.time() - start_time < max_wait:
351
+ for idx in range(num_images):
352
+ if not os.path.exists(IMG_DIR / f"{query_id}_{idx}.jpg"):
353
+ print(
354
+ f"Message generator: Full image not ready for query_id: {query_id}, idx: {idx}"
355
+ )
356
+ continue
 
 
 
 
 
357
  else:
358
+ print(
359
+ f"Message generator: image ready for query_id: {query_id}, idx: {idx}"
360
+ )
361
+ images.append(Image.open(IMG_DIR / f"{query_id}_{idx}.jpg"))
362
+ await asyncio.sleep(0.2)
363
+ # yield message with number of images ready
364
+ yield f"event: message\ndata: Generating response based on {len(images)} images.\n\n"
365
  if not images:
366
  yield "event: message\ndata: I am sorry, I do not have enough information in the image to answer your question.\n\n"
367
  yield "event: close\ndata: \n\n"