thomasht86 commited on
Commit
4775a9f
β€’
1 Parent(s): 8e4fbd2

Upload folder using huggingface_hub

Browse files
backend/colpali.py CHANGED
@@ -3,7 +3,7 @@
3
  import torch
4
  from PIL import Image
5
  import numpy as np
6
- from typing import cast
7
  from pathlib import Path
8
  import base64
9
  from io import BytesIO
@@ -119,7 +119,7 @@ def gen_similarity_maps(
119
  token_idx_map: dict,
120
  images: List[Union[Path, str]],
121
  vespa_sim_maps: List[str],
122
- ) -> List[Dict[str, str]]:
123
  """
124
  Generate similarity maps for the given images and query, and return base64-encoded blended images.
125
 
@@ -134,8 +134,9 @@ def gen_similarity_maps(
134
  images (List[Union[Path, str]]): List of image paths or base64-encoded strings.
135
  vespa_sim_maps (List[str]): List of Vespa similarity maps.
136
 
137
- Returns:
138
- List[Dict[str, str]]: A list where each item is a dictionary mapping tokens to base64-encoded blended images.
 
139
  """
140
 
141
  start = time.perf_counter()
@@ -302,11 +303,7 @@ def gen_similarity_maps(
302
 
303
  # Store the base64-encoded image
304
  result_per_image[token] = blended_img_base64
305
- results.append(result_per_image)
306
- end3 = time.perf_counter()
307
- print(f"Collecting blended images took: {end3 - start3} s")
308
- print(f"Total heatmap generation took: {end3 - start} s")
309
- return results
310
 
311
 
312
  def get_query_embeddings_and_token_map(
@@ -369,23 +366,32 @@ async def query_vespa_default(
369
  async def query_vespa_bm25(
370
  app: Vespa,
371
  query: str,
 
372
  hits: int = 3,
373
  timeout: str = "10s",
374
  **kwargs,
375
  ) -> dict:
376
  async with app.asyncio(connections=1, total_timeout=120) as session:
 
 
 
377
  response: VespaQueryResponse = await session.query(
378
  body={
379
- "yql": "select id,title,url,full_image,page_number,snippet,text from pdf_page where userQuery();",
380
  "ranking": "bm25",
381
  "query": query,
382
  "timeout": timeout,
383
  "hits": hits,
 
384
  "presentation.timing": True,
385
  **kwargs,
386
  },
387
  )
388
  assert response.is_successful(), response.json
 
 
 
 
389
  return format_query_results(query, response)
390
 
391
 
@@ -451,7 +457,7 @@ async def query_vespa_nearest_neighbor(
451
  **query_tensors,
452
  "presentation.timing": True,
453
  # if we use rank({nn_string}, userQuery()), dynamic summary doesn't work, see https://github.com/vespa-engine/vespa/issues/28704
454
- "yql": f"select id,title,snippet,text,url,full_image,page_number from pdf_page where {nn_string} or userQuery()",
455
  "ranking.profile": "retrieval-and-rerank",
456
  "timeout": timeout,
457
  "hits": hits,
@@ -489,7 +495,7 @@ async def get_result_from_query(
489
  elif ranking == "bm25+colpali":
490
  result = await query_vespa_default(app, query, q_embs)
491
  elif ranking == "bm25":
492
- result = await query_vespa_bm25(app, query)
493
  else:
494
  raise ValueError(f"Unsupported ranking: {ranking}")
495
  # Print score, title id, and text of the results
@@ -509,6 +515,8 @@ def add_sim_maps_to_result(
509
  query: str,
510
  q_embs: Any,
511
  token_to_idx: Dict[str, int],
 
 
512
  ) -> Dict[str, Any]:
513
  vit_config = load_vit_config(model)
514
  imgs: List[str] = []
@@ -520,7 +528,7 @@ def add_sim_maps_to_result(
520
  vespa_sim_map = single_result["fields"].get("summaryfeatures", None)
521
  if vespa_sim_map:
522
  vespa_sim_maps.append(vespa_sim_map)
523
- sim_map_imgs = gen_similarity_maps(
524
  model=model,
525
  processor=processor,
526
  device=model.device,
@@ -531,9 +539,14 @@ def add_sim_maps_to_result(
531
  images=imgs,
532
  vespa_sim_maps=vespa_sim_maps,
533
  )
534
- for single_result, sim_map_dict in zip(result["root"]["children"], sim_map_imgs):
535
- for token, sim_mapb64 in sim_map_dict.items():
536
- single_result["fields"][f"sim_map_{token}"] = sim_mapb64
 
 
 
 
 
537
  return result
538
 
539
 
 
3
  import torch
4
  from PIL import Image
5
  import numpy as np
6
+ from typing import cast, Generator
7
  from pathlib import Path
8
  import base64
9
  from io import BytesIO
 
119
  token_idx_map: dict,
120
  images: List[Union[Path, str]],
121
  vespa_sim_maps: List[str],
122
+ ) -> Generator[Tuple[int, str, str], None, None]:
123
  """
124
  Generate similarity maps for the given images and query, and return base64-encoded blended images.
125
 
 
134
  images (List[Union[Path, str]]): List of image paths or base64-encoded strings.
135
  vespa_sim_maps (List[str]): List of Vespa similarity maps.
136
 
137
+ Yields:
138
+ Tuple[int, str, str]: A tuple containing the image index, the selected token, and the base64-encoded image.
139
+
140
  """
141
 
142
  start = time.perf_counter()
 
303
 
304
  # Store the base64-encoded image
305
  result_per_image[token] = blended_img_base64
306
+ yield idx, token, blended_img_base64
 
 
 
 
307
 
308
 
309
  def get_query_embeddings_and_token_map(
 
366
  async def query_vespa_bm25(
367
  app: Vespa,
368
  query: str,
369
+ q_emb: torch.Tensor,
370
  hits: int = 3,
371
  timeout: str = "10s",
372
  **kwargs,
373
  ) -> dict:
374
  async with app.asyncio(connections=1, total_timeout=120) as session:
375
+ query_embedding = format_q_embs(q_emb)
376
+
377
+ start = time.perf_counter()
378
  response: VespaQueryResponse = await session.query(
379
  body={
380
+ "yql": "select id,title,url,full_image,page_number,snippet,text,summaryfeatures from pdf_page where userQuery();",
381
  "ranking": "bm25",
382
  "query": query,
383
  "timeout": timeout,
384
  "hits": hits,
385
+ "input.query(qt)": query_embedding,
386
  "presentation.timing": True,
387
  **kwargs,
388
  },
389
  )
390
  assert response.is_successful(), response.json
391
+ stop = time.perf_counter()
392
+ print(
393
+ f"Query time + data transfer took: {stop - start} s, vespa said searchtime was {response.json.get('timing', {}).get('searchtime', -1)} s"
394
+ )
395
  return format_query_results(query, response)
396
 
397
 
 
457
  **query_tensors,
458
  "presentation.timing": True,
459
  # if we use rank({nn_string}, userQuery()), dynamic summary doesn't work, see https://github.com/vespa-engine/vespa/issues/28704
460
+ "yql": f"select id,title,snippet,text,url,full_image,page_number,summaryfeatures from pdf_page where {nn_string} or userQuery()",
461
  "ranking.profile": "retrieval-and-rerank",
462
  "timeout": timeout,
463
  "hits": hits,
 
495
  elif ranking == "bm25+colpali":
496
  result = await query_vespa_default(app, query, q_embs)
497
  elif ranking == "bm25":
498
+ result = await query_vespa_bm25(app, query, q_embs)
499
  else:
500
  raise ValueError(f"Unsupported ranking: {ranking}")
501
  # Print score, title id, and text of the results
 
515
  query: str,
516
  q_embs: Any,
517
  token_to_idx: Dict[str, int],
518
+ query_id: str,
519
+ result_cache,
520
  ) -> Dict[str, Any]:
521
  vit_config = load_vit_config(model)
522
  imgs: List[str] = []
 
528
  vespa_sim_map = single_result["fields"].get("summaryfeatures", None)
529
  if vespa_sim_map:
530
  vespa_sim_maps.append(vespa_sim_map)
531
+ sim_map_imgs_generator = gen_similarity_maps(
532
  model=model,
533
  processor=processor,
534
  device=model.device,
 
539
  images=imgs,
540
  vespa_sim_maps=vespa_sim_maps,
541
  )
542
+ for img_idx, token, sim_mapb64 in sim_map_imgs_generator:
543
+ print(f"Created sim map for image {img_idx} and token {token}")
544
+ result["root"]["children"][img_idx]["fields"][f"sim_map_{token}"] = sim_mapb64
545
+ # Update result_cache with the new sim_map
546
+ result_cache.set(query_id, result)
547
+ # for single_result, sim_map_dict in zip(result["root"]["children"], sim_map_imgs):
548
+ # for token, sim_mapb64 in sim_map_dict.items():
549
+ # single_result["fields"][f"sim_map_{token}"] = sim_mapb64
550
  return result
551
 
552
 
colpali-with-snippets/schemas/pdf_page.sd CHANGED
@@ -67,15 +67,28 @@ schema pdf_page {
67
  }
68
 
69
  rank-profile bm25 {
 
 
 
 
 
70
  first-phase {
71
  expression: bm25(title) + bm25(text)
72
  }
 
 
 
 
 
 
 
 
 
73
  }
74
 
75
  rank-profile default {
76
  inputs {
77
- query(qt) tensor<float>(querytoken{}, v[128])
78
-
79
  }
80
  function max_sim() {
81
  expression {
@@ -92,13 +105,6 @@ schema pdf_page {
92
 
93
  }
94
  }
95
- function similarities() {
96
- expression {
97
- sum(
98
- query(qt) * unpack_bits(attribute(embedding)), v
99
- )
100
- }
101
- }
102
  function bm25_score() {
103
  expression {
104
  bm25(title) + bm25(text)
@@ -115,6 +121,13 @@ schema pdf_page {
115
  max_sim
116
  }
117
  }
 
 
 
 
 
 
 
118
  summary-features: similarities
119
  }
120
  rank-profile retrieval-and-rerank {
@@ -229,5 +242,13 @@ schema pdf_page {
229
  max_sim
230
  }
231
  }
 
 
 
 
 
 
 
 
232
  }
233
  }
 
67
  }
68
 
69
  rank-profile bm25 {
70
+
71
+ inputs {
72
+ query(qt) tensor<float>(querytoken{}, v[128]) # only used here to generate image similarity map
73
+ }
74
+
75
  first-phase {
76
  expression: bm25(title) + bm25(text)
77
  }
78
+
79
+ function similarities() {
80
+ expression {
81
+ sum(
82
+ query(qt) * unpack_bits(attribute(embedding)), v
83
+ )
84
+ }
85
+ }
86
+ summary-features: similarities
87
  }
88
 
89
  rank-profile default {
90
  inputs {
91
+ query(qt) tensor<float>(querytoken{}, v[128])
 
92
  }
93
  function max_sim() {
94
  expression {
 
105
 
106
  }
107
  }
 
 
 
 
 
 
 
108
  function bm25_score() {
109
  expression {
110
  bm25(title) + bm25(text)
 
121
  max_sim
122
  }
123
  }
124
+ function similarities() {
125
+ expression {
126
+ sum(
127
+ query(qt) * unpack_bits(attribute(embedding)), v
128
+ )
129
+ }
130
+ }
131
  summary-features: similarities
132
  }
133
  rank-profile retrieval-and-rerank {
 
242
  max_sim
243
  }
244
  }
245
+ function similarities() {
246
+ expression {
247
+ sum(
248
+ query(qt) * unpack_bits(attribute(embedding)), v
249
+ )
250
+ }
251
+ }
252
+ summary-features: similarities
253
  }
254
  }
frontend/app.py CHANGED
@@ -1,7 +1,7 @@
1
- from urllib.parse import quote_plus
2
  from typing import Optional
 
3
 
4
- from fasthtml.components import H1, H2, Div, Form, Img, P, Span, NotStr
5
  from fasthtml.xtend import A, Script
6
  from lucide_fasthtml import Lucide
7
  from shad4fast import Badge, Button, Input, Label, RadioGroup, RadioGroupItem
@@ -275,14 +275,22 @@ def SearchResult(results: list, query_id: Optional[str] = None):
275
  H2(fields["title"], cls="text-xl font-semibold"),
276
  P(
277
  "Page " + str(fields["page_number"]),
278
- cls="text-muted-foreground",
 
 
 
 
 
 
 
 
 
 
279
  ),
280
  P(
281
- "Relevance score: " + str(result["relevance"]),
282
- cls="text-muted-foreground",
283
  ),
284
- P(NotStr(fields["snippet"]), cls="text-muted-foreground"),
285
- P(NotStr(fields["text"]), cls="text-muted-foreground"),
286
  cls="text-sm grid gap-y-4",
287
  ),
288
  cls="bg-background px-3 py-5 hidden md:block",
 
 
1
  from typing import Optional
2
+ from urllib.parse import quote_plus
3
 
4
+ from fasthtml.components import H1, H2, Div, Form, Img, NotStr, P, Span
5
  from fasthtml.xtend import A, Script
6
  from lucide_fasthtml import Lucide
7
  from shad4fast import Badge, Button, Input, Label, RadioGroup, RadioGroupItem
 
275
  H2(fields["title"], cls="text-xl font-semibold"),
276
  P(
277
  "Page " + str(fields["page_number"]),
278
+ cls="text-foreground font-mono bold",
279
+ ),
280
+ Div(
281
+ Badge(
282
+ "Relevance score: " + str(result["relevance"]),
283
+ cls="flex gap-1.5 items-center justify-center",
284
+ ),
285
+ ),
286
+ P(
287
+ NotStr(fields.get("snippet", "")),
288
+ cls="text-highlight text-muted-foreground",
289
  ),
290
  P(
291
+ NotStr(fields.get("text", "")),
292
+ cls="text-highlight text-muted-foreground",
293
  ),
 
 
294
  cls="text-sm grid gap-y-4",
295
  ),
296
  cls="bg-background px-3 py-5 hidden md:block",
globals.css CHANGED
@@ -165,6 +165,11 @@
165
  }
166
  }
167
 
 
 
 
 
 
168
  .tokens-button {
169
  background-color: #B7E2F1;
170
  color: #2E2F27;
 
165
  }
166
  }
167
 
168
+ .text-highlight strong {
169
+ background-color: #61D790;
170
+ color: #2E2F27;
171
+ }
172
+
173
  .tokens-button {
174
  background-color: #B7E2F1;
175
  color: #2E2F27;
main.py CHANGED
@@ -40,6 +40,9 @@ app, rt = fast_app(
40
  vespa_app: Vespa = get_vespa_app()
41
 
42
  result_cache = LRUCache(max_size=20) # Each result can be ~10MB
 
 
 
43
  thread_pool = ThreadPoolExecutor()
44
 
45
 
@@ -97,7 +100,17 @@ async def get(request, query: str, nn: bool = True):
97
  )
98
  # Generate a unique query_id based on the query and ranking value
99
  query_id = generate_query_id(query + ranking_value)
100
-
 
 
 
 
 
 
 
 
 
 
101
  # Fetch model and processor
102
  manager = ModelManager.get_instance()
103
  model = manager.model
@@ -116,19 +129,26 @@ async def get(request, query: str, nn: bool = True):
116
  ranking=ranking_value,
117
  )
118
  end = time.perf_counter()
119
- print(f"Search results fetched in {end - start:.2f} seconds, Vespa says searchtime was {result['timing']['searchtime']} seconds")
 
 
120
  # Start generating the similarity map in the background
121
  asyncio.create_task(
122
  generate_similarity_map(
123
  model, processor, query, q_embs, token_to_idx, result, query_id
124
  )
125
  )
 
 
 
 
 
126
  search_results = (
127
  result["root"]["children"]
128
  if "root" in result and "children" in result["root"]
129
  else []
130
  )
131
- return SearchResult(search_results, query_id)
132
 
133
 
134
  async def generate_similarity_map(
@@ -143,22 +163,25 @@ async def generate_similarity_map(
143
  query=query,
144
  q_embs=q_embs,
145
  token_to_idx=token_to_idx,
 
 
146
  )
147
  sim_map_result = await loop.run_in_executor(thread_pool, sim_map_task)
148
  result_cache.set(query_id, sim_map_result)
 
149
 
150
 
151
  @app.get("/updated_search_results")
152
  async def updated_search_results(query_id: str):
153
- data = result_cache.get(query_id)
154
- if data is None:
155
  return HTMLResponse(status_code=204)
156
- search_results = (
157
- data["root"]["children"]
158
- if "root" in data and "children" in data["root"]
159
- else []
160
- )
161
- updated_content = SearchResult(results=search_results, query_id=None)
162
  return updated_content
163
 
164
 
 
40
  vespa_app: Vespa = get_vespa_app()
41
 
42
  result_cache = LRUCache(max_size=20) # Each result can be ~10MB
43
+ task_cache = LRUCache(
44
+ max_size=1000
45
+ ) # Map from query_id to boolean value - False if not all results are ready.
46
  thread_pool = ThreadPoolExecutor()
47
 
48
 
 
100
  )
101
  # Generate a unique query_id based on the query and ranking value
102
  query_id = generate_query_id(query + ranking_value)
103
+ # See if results are already in cache
104
+ if result_cache.get(query_id):
105
+ print(f"Results for query_id {query_id} already in cache")
106
+ result = result_cache.get(query_id)
107
+ search_results = get_results_children(result)
108
+ # If task is completed, return the results, but no query_id
109
+ if task_cache.get(query_id):
110
+ return SearchResult(search_results, None)
111
+ # If task is not completed, return the results with query_id
112
+ return SearchResult(search_results, query_id)
113
+ task_cache.set(query_id, False)
114
  # Fetch model and processor
115
  manager = ModelManager.get_instance()
116
  model = manager.model
 
129
  ranking=ranking_value,
130
  )
131
  end = time.perf_counter()
132
+ print(
133
+ f"Search results fetched in {end - start:.2f} seconds, Vespa says searchtime was {result['timing']['searchtime']} seconds"
134
+ )
135
  # Start generating the similarity map in the background
136
  asyncio.create_task(
137
  generate_similarity_map(
138
  model, processor, query, q_embs, token_to_idx, result, query_id
139
  )
140
  )
141
+ search_results = get_results_children(result)
142
+ return SearchResult(search_results, query_id)
143
+
144
+
145
+ def get_results_children(result):
146
  search_results = (
147
  result["root"]["children"]
148
  if "root" in result and "children" in result["root"]
149
  else []
150
  )
151
+ return search_results
152
 
153
 
154
  async def generate_similarity_map(
 
163
  query=query,
164
  q_embs=q_embs,
165
  token_to_idx=token_to_idx,
166
+ query_id=query_id,
167
+ result_cache=result_cache,
168
  )
169
  sim_map_result = await loop.run_in_executor(thread_pool, sim_map_task)
170
  result_cache.set(query_id, sim_map_result)
171
+ task_cache.set(query_id, True)
172
 
173
 
174
  @app.get("/updated_search_results")
175
  async def updated_search_results(query_id: str):
176
+ result = result_cache.get(query_id)
177
+ if result is None:
178
  return HTMLResponse(status_code=204)
179
+ search_results = get_results_children(result)
180
+ # Check if task is completed - Stop polling if it is
181
+ if task_cache.get(query_id):
182
+ updated_content = SearchResult(results=search_results, query_id=None)
183
+ else:
184
+ updated_content = SearchResult(results=search_results, query_id=query_id)
185
  return updated_content
186
 
187
 
output.css CHANGED
@@ -1117,6 +1117,10 @@ body {
1117
  justify-items: center;
1118
  }
1119
 
 
 
 
 
1120
  .gap-2 {
1121
  gap: 0.5rem;
1122
  }
@@ -1949,6 +1953,11 @@ body {
1949
  }
1950
  }
1951
 
 
 
 
 
 
1952
  .tokens-button {
1953
  background-color: #B7E2F1;
1954
  color: #2E2F27;
 
1117
  justify-items: center;
1118
  }
1119
 
1120
+ .gap-1\.5 {
1121
+ gap: 0.375rem;
1122
+ }
1123
+
1124
  .gap-2 {
1125
  gap: 0.5rem;
1126
  }
 
1953
  }
1954
  }
1955
 
1956
+ .text-highlight strong {
1957
+ background-color: #61D790;
1958
+ color: #2E2F27;
1959
+ }
1960
+
1961
  .tokens-button {
1962
  background-color: #B7E2F1;
1963
  color: #2E2F27;