Spaces:
Running
on
T4
Running
on
T4
thomasht86
commited on
Commit
β’
be59b6e
1
Parent(s):
2c5fb94
Upload folder using huggingface_hub
Browse files- README.md +7 -1
- backend/cache.py +26 -0
- backend/colpali.py +322 -290
- backend/modelmanager.py +22 -0
- colpali-with-snippets/schemas/pdf_page.sd +233 -0
- colpali-with-snippets/search/query-profiles/default.xml +2 -0
- colpali-with-snippets/search/query-profiles/types/root.xml +2 -0
- colpali-with-snippets/services.xml +43 -0
- colpalidemo/schemas/pdf_page.sd +7 -1
- deploy_vespa_app.py +6 -2
- feed_vespa.py +2 -0
- frontend/app.py +165 -47
- globals.css +14 -0
- icons.py +1 -1
- main.py +76 -56
- output.css +76 -43
- ruff.toml +77 -0
README.md
CHANGED
@@ -102,8 +102,14 @@ python feed_vespa.py --vespa_app_url https://myapp.z.vespa-app.cloud --vespa_clo
|
|
102 |
|
103 |
### Connecting to the Vespa app and querying
|
104 |
|
105 |
-
As a first step,
|
106 |
|
107 |
```bash
|
108 |
python query_vespa.py
|
109 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
### Connecting to the Vespa app and querying
|
104 |
|
105 |
+
As a first step, you can run the `query_vespa.py` script to run some sample queries against the Vespa app:
|
106 |
|
107 |
```bash
|
108 |
python query_vespa.py
|
109 |
```
|
110 |
+
|
111 |
+
### Starting the front-end
|
112 |
+
|
113 |
+
```bash
|
114 |
+
python main.py
|
115 |
+
```
|
backend/cache.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
|
3 |
+
|
4 |
+
# Initialize LRU Cache
|
5 |
+
class LRUCache:
|
6 |
+
def __init__(self, max_size=20):
|
7 |
+
self.max_size = max_size
|
8 |
+
self.cache = OrderedDict()
|
9 |
+
|
10 |
+
def get(self, key):
|
11 |
+
if key in self.cache:
|
12 |
+
self.cache.move_to_end(key)
|
13 |
+
return self.cache[key]
|
14 |
+
return None
|
15 |
+
|
16 |
+
def set(self, key, value):
|
17 |
+
if key in self.cache:
|
18 |
+
self.cache.move_to_end(key)
|
19 |
+
else:
|
20 |
+
if len(self.cache) >= self.max_size:
|
21 |
+
self.cache.popitem(last=False)
|
22 |
+
self.cache[key] = value
|
23 |
+
|
24 |
+
def delete(self, key):
|
25 |
+
if key in self.cache:
|
26 |
+
del self.cache[key]
|
backend/colpali.py
CHANGED
@@ -4,45 +4,33 @@ import torch
|
|
4 |
from PIL import Image
|
5 |
import numpy as np
|
6 |
from typing import cast
|
7 |
-
import pprint
|
8 |
from pathlib import Path
|
9 |
import base64
|
10 |
from io import BytesIO
|
11 |
-
from typing import Union, Tuple
|
12 |
import matplotlib
|
|
|
13 |
import re
|
|
|
|
|
|
|
|
|
14 |
|
15 |
from colpali_engine.models import ColPali, ColPaliProcessor
|
16 |
from colpali_engine.utils.torch_utils import get_torch_device
|
17 |
from einops import rearrange
|
18 |
-
from vidore_benchmark.interpretability.plot_utils import plot_similarity_heatmap
|
19 |
from vidore_benchmark.interpretability.torch_utils import (
|
20 |
normalize_similarity_map_per_query_token,
|
21 |
)
|
22 |
from vidore_benchmark.interpretability.vit_configs import VIT_CONFIG
|
23 |
-
from vidore_benchmark.utils.image_utils import scale_image
|
24 |
from vespa.application import Vespa
|
25 |
from vespa.io import VespaQueryResponse
|
26 |
|
27 |
matplotlib.use("Agg")
|
28 |
|
29 |
MAX_QUERY_TERMS = 64
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
COLPALI_GEMMA_MODEL_ID = "vidore--colpaligemma-3b-pt-448-base"
|
34 |
-
COLPALI_GEMMA_MODEL_SNAPSHOT = "12c59eb7e23bc4c26876f7be7c17760d5d3a1ffa"
|
35 |
-
COLPALI_GEMMA_MODEL_PATH = (
|
36 |
-
Path().home()
|
37 |
-
/ f".cache/huggingface/hub/models--{COLPALI_GEMMA_MODEL_ID}/snapshots/{COLPALI_GEMMA_MODEL_SNAPSHOT}"
|
38 |
-
)
|
39 |
-
COLPALI_MODEL_ID = "vidore--colpali-v1.2"
|
40 |
-
COLPALI_MODEL_SNAPSHOT = "9912ce6f8a462d8cf2269f5606eabbd2784e764f"
|
41 |
-
COLPALI_MODEL_PATH = (
|
42 |
-
Path().home()
|
43 |
-
/ f".cache/huggingface/hub/models--{COLPALI_MODEL_ID}/snapshots/{COLPALI_MODEL_SNAPSHOT}"
|
44 |
-
)
|
45 |
-
COLPALI_GEMMA_MODEL_NAME = COLPALI_GEMMA_MODEL_ID.replace("--", "/")
|
46 |
|
47 |
|
48 |
def load_model() -> Tuple[ColPali, ColPaliProcessor]:
|
@@ -73,195 +61,241 @@ def load_vit_config(model):
|
|
73 |
return vit_config
|
74 |
|
75 |
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
try:
|
88 |
-
image = Image.open(image)
|
89 |
-
except Exception as e:
|
90 |
-
raise ValueError(f"Failed to open image from path: {e}")
|
91 |
-
elif isinstance(image, str):
|
92 |
-
# image is b64 string
|
93 |
-
try:
|
94 |
-
image = Image.open(BytesIO(base64.b64decode(image)))
|
95 |
-
except Exception as e:
|
96 |
-
raise ValueError(f"Failed to open image from b64: {e}")
|
97 |
-
|
98 |
-
# Preview the image
|
99 |
-
scale_image(image, 512)
|
100 |
-
# Preprocess inputs
|
101 |
-
input_text_processed = processor.process_queries([query]).to(device)
|
102 |
-
input_image_processed = processor.process_images([image]).to(device)
|
103 |
-
# Forward passes
|
104 |
-
with torch.no_grad():
|
105 |
-
output_text = model.forward(**input_text_processed)
|
106 |
-
output_image = model.forward(**input_image_processed)
|
107 |
-
# output_image is the tensor that we could get from the Vespa query
|
108 |
-
# Print shape of output_text and output_image
|
109 |
-
# Output image shape: torch.Size([1, 1030, 128])
|
110 |
-
# Remove the special tokens from the output
|
111 |
-
output_image = output_image[
|
112 |
-
:, : processor.image_seq_length, :
|
113 |
-
] # (1, n_patches_x * n_patches_y, dim)
|
114 |
-
|
115 |
-
# Rearrange the output image tensor to explicitly represent the 2D grid of patches
|
116 |
-
output_image = rearrange(
|
117 |
-
output_image,
|
118 |
-
"b (h w) c -> b h w c",
|
119 |
-
h=vit_config.n_patch_per_dim,
|
120 |
-
w=vit_config.n_patch_per_dim,
|
121 |
-
) # (1, n_patches_x, n_patches_y, dim)
|
122 |
-
# Get the similarity map
|
123 |
-
similarity_map = torch.einsum(
|
124 |
-
"bnk,bijk->bnij", output_text, output_image
|
125 |
-
) # (1, query_tokens, n_patches_x, n_patches_y)
|
126 |
-
|
127 |
-
# Normalize the similarity map
|
128 |
-
similarity_map_normalized = normalize_similarity_map_per_query_token(
|
129 |
-
similarity_map
|
130 |
-
) # (1, query_tokens, n_patches_x, n_patches_y)
|
131 |
-
# Use this cell output to choose a token using its index
|
132 |
-
query_tokens = processor.tokenizer.tokenize(
|
133 |
-
processor.decode(input_text_processed.input_ids[0])
|
134 |
-
)
|
135 |
-
# Choose a token
|
136 |
-
token_idx = (
|
137 |
-
10 # e.g. if "12: 'βKazakhstan',", set 12 to choose the token 'Kazakhstan'
|
138 |
-
)
|
139 |
-
selected_token = processor.decode(input_text_processed.input_ids[0, token_idx])
|
140 |
-
# strip whitespace
|
141 |
-
selected_token = selected_token.strip()
|
142 |
-
print(f"Selected token: `{selected_token}`")
|
143 |
-
# Retrieve the similarity map for the chosen token
|
144 |
-
pprint.pprint({idx: val for idx, val in enumerate(query_tokens)})
|
145 |
-
# Resize the image to square
|
146 |
-
input_image_square = image.resize((vit_config.resolution, vit_config.resolution))
|
147 |
-
|
148 |
-
# Plot the similarity map
|
149 |
-
fig, ax = plot_similarity_heatmap(
|
150 |
-
input_image_square,
|
151 |
-
patch_size=vit_config.patch_size,
|
152 |
-
image_resolution=vit_config.resolution,
|
153 |
-
similarity_map=similarity_map_normalized[0, token_idx, :, :],
|
154 |
-
)
|
155 |
-
ax = annotate_plot(ax, selected_token)
|
156 |
-
return fig, ax
|
157 |
-
|
158 |
-
|
159 |
-
# def save_figure(fig, filename: str = "similarity_map.png"):
|
160 |
-
# fig.savefig(
|
161 |
-
# OUTPUT_DIR / filename,
|
162 |
-
# bbox_inches="tight",
|
163 |
-
# pad_inches=0,
|
164 |
-
# )
|
165 |
|
166 |
|
167 |
def annotate_plot(ax, query, selected_token):
|
168 |
-
# Add the query text
|
169 |
-
ax.
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
ha="center",
|
176 |
va="center",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
fontsize=18,
|
178 |
-
color="
|
179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
)
|
181 |
return ax
|
182 |
|
183 |
|
184 |
-
def
|
185 |
-
processor: ColPaliProcessor,
|
186 |
model: ColPali,
|
|
|
187 |
device,
|
188 |
vit_config,
|
189 |
query: str,
|
190 |
query_embs: torch.Tensor,
|
191 |
token_idx_map: dict,
|
192 |
-
|
193 |
-
|
194 |
-
):
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
#
|
215 |
-
|
216 |
-
|
217 |
-
#
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
261 |
|
262 |
|
263 |
def get_query_embeddings_and_token_map(
|
264 |
-
processor, model, query
|
265 |
) -> Tuple[torch.Tensor, dict]:
|
266 |
inputs = processor.process_queries([query]).to(model.device)
|
267 |
with torch.no_grad():
|
@@ -294,9 +328,11 @@ async def query_vespa_default(
|
|
294 |
) -> dict:
|
295 |
async with app.asyncio(connections=1, total_timeout=120) as session:
|
296 |
query_embedding = format_q_embs(q_emb)
|
|
|
|
|
297 |
response: VespaQueryResponse = await session.query(
|
298 |
body={
|
299 |
-
"yql": "select id,title,url,
|
300 |
"ranking": "default",
|
301 |
"query": query,
|
302 |
"timeout": timeout,
|
@@ -307,6 +343,32 @@ async def query_vespa_default(
|
|
307 |
},
|
308 |
)
|
309 |
assert response.is_successful(), response.json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
310 |
return format_query_results(query, response)
|
311 |
|
312 |
|
@@ -371,10 +433,12 @@ async def query_vespa_nearest_neighbor(
|
|
371 |
body={
|
372 |
**query_tensors,
|
373 |
"presentation.timing": True,
|
374 |
-
|
|
|
375 |
"ranking.profile": "retrieval-and-rerank",
|
376 |
"timeout": timeout,
|
377 |
"hits": hits,
|
|
|
378 |
**kwargs,
|
379 |
},
|
380 |
)
|
@@ -383,8 +447,8 @@ async def query_vespa_nearest_neighbor(
|
|
383 |
|
384 |
|
385 |
def is_special_token(token: str) -> bool:
|
386 |
-
# Pattern for tokens that start with '<', numbers, whitespace, or single characters
|
387 |
-
pattern = re.compile(r"^<.*$|^\d+$|^\s
|
388 |
if pattern.match(token):
|
389 |
return True
|
390 |
return False
|
@@ -395,111 +459,64 @@ async def get_result_from_query(
|
|
395 |
processor: ColPaliProcessor,
|
396 |
model: ColPali,
|
397 |
query: str,
|
398 |
-
|
399 |
-
|
400 |
-
|
|
|
401 |
# Get the query embeddings and token map
|
402 |
print(query)
|
403 |
-
q_embs, token_to_idx = get_query_embeddings_and_token_map(
|
404 |
-
processor, model, query, dummy_image
|
405 |
-
)
|
406 |
-
print(token_to_idx)
|
407 |
-
# Use the token map to choose a token randomly for now
|
408 |
-
# Dynamically select a token containing 'water'
|
409 |
|
410 |
-
|
|
|
411 |
result = await query_vespa_nearest_neighbor(app, query, q_embs)
|
412 |
-
|
413 |
result = await query_vespa_default(app, query, q_embs)
|
414 |
-
|
|
|
|
|
|
|
|
|
415 |
for idx, child in enumerate(result["root"]["children"]):
|
416 |
print(
|
417 |
f"Result {idx+1}: {child['relevance']}, {child['fields']['title']}, {child['fields']['id']}"
|
418 |
)
|
419 |
-
|
420 |
-
|
421 |
-
for single_result in result["root"]["children"]:
|
422 |
-
img = single_result["fields"]["image"]
|
423 |
-
for token in token_to_idx:
|
424 |
-
if is_special_token(token):
|
425 |
-
print(f"Skipping special token: {token}")
|
426 |
-
continue
|
427 |
-
fig, ax = gen_similarity_map_new(
|
428 |
-
processor,
|
429 |
-
model,
|
430 |
-
model.device,
|
431 |
-
load_vit_config(model),
|
432 |
-
query,
|
433 |
-
q_embs,
|
434 |
-
token_to_idx,
|
435 |
-
token,
|
436 |
-
img,
|
437 |
-
)
|
438 |
-
sim_map = base64.b64encode(fig.canvas.tostring_rgb()).decode("utf-8")
|
439 |
-
single_result["fields"][f"sim_map_{token}"] = sim_map
|
440 |
return result
|
441 |
|
442 |
|
443 |
-
def
|
444 |
-
result
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
result["root"]["
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
elt0["fields"]["title"] = "ConocoPhillips 2023 Sustainability Report"
|
472 |
-
elt0["fields"]["page_number"] = 50
|
473 |
-
elt0["fields"]["image"] = "empty for now - is base64 encoded image"
|
474 |
-
result["root"]["children"].append(elt0)
|
475 |
-
elt1 = {}
|
476 |
-
elt1["id"] = "index:colpalidemo_content/0/b927c4979f0beaf0d7fab8e9"
|
477 |
-
elt1["relevance"] = 2313.7529950886965
|
478 |
-
elt1["source"] = "colpalidemo_content"
|
479 |
-
elt1["fields"] = {}
|
480 |
-
elt1["fields"]["id"] = "9f2fc0aa02c9561adfaa1451c875658f"
|
481 |
-
elt1["fields"]["url"] = (
|
482 |
-
"https://static.conocophillips.com/files/resources/conocophillips-2023-managing-climate-related-risks.pdf"
|
483 |
-
)
|
484 |
-
elt1["fields"]["title"] = "ConocoPhillips Managing Climate Related Risks"
|
485 |
-
elt1["fields"]["page_number"] = 44
|
486 |
-
elt1["fields"]["image"] = "empty for now - is base64 encoded image"
|
487 |
-
result["root"]["children"].append(elt1)
|
488 |
-
elt2 = {}
|
489 |
-
elt2["id"] = "index:colpalidemo_content/0/9632d72238829d6afefba6c9"
|
490 |
-
elt2["relevance"] = 2312.230182081461
|
491 |
-
elt2["source"] = "colpalidemo_content"
|
492 |
-
elt2["fields"] = {}
|
493 |
-
elt2["fields"]["id"] = "d638ded1ddcb446268b289b3f65430fd"
|
494 |
-
elt2["fields"]["url"] = (
|
495 |
-
"https://static.conocophillips.com/files/resources/24-0976-sustainability-highlights_nature.pdf"
|
496 |
-
)
|
497 |
-
elt2["fields"]["title"] = (
|
498 |
-
"ConocoPhillips Sustainability Highlights - Nature (24-0976)"
|
499 |
)
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
return result
|
504 |
|
505 |
|
@@ -513,9 +530,24 @@ if __name__ == "__main__":
|
|
513 |
/ "assets"
|
514 |
/ "ConocoPhillips Sustainability Highlights - Nature (24-0976).png"
|
515 |
)
|
516 |
-
|
517 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
518 |
)
|
519 |
-
|
520 |
-
|
|
|
|
|
521 |
print("Done")
|
|
|
4 |
from PIL import Image
|
5 |
import numpy as np
|
6 |
from typing import cast
|
|
|
7 |
from pathlib import Path
|
8 |
import base64
|
9 |
from io import BytesIO
|
10 |
+
from typing import Union, Tuple, List, Dict, Any
|
11 |
import matplotlib
|
12 |
+
import matplotlib.cm as cm
|
13 |
import re
|
14 |
+
import io
|
15 |
+
|
16 |
+
import json
|
17 |
+
import time
|
18 |
|
19 |
from colpali_engine.models import ColPali, ColPaliProcessor
|
20 |
from colpali_engine.utils.torch_utils import get_torch_device
|
21 |
from einops import rearrange
|
|
|
22 |
from vidore_benchmark.interpretability.torch_utils import (
|
23 |
normalize_similarity_map_per_query_token,
|
24 |
)
|
25 |
from vidore_benchmark.interpretability.vit_configs import VIT_CONFIG
|
|
|
26 |
from vespa.application import Vespa
|
27 |
from vespa.io import VespaQueryResponse
|
28 |
|
29 |
matplotlib.use("Agg")
|
30 |
|
31 |
MAX_QUERY_TERMS = 64
|
32 |
+
|
33 |
+
COLPALI_GEMMA_MODEL_NAME = "vidore/colpaligemma-3b-pt-448-base"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
|
36 |
def load_model() -> Tuple[ColPali, ColPaliProcessor]:
|
|
|
61 |
return vit_config
|
62 |
|
63 |
|
64 |
+
def save_figure(fig, filename: str = "similarity_map.png"):
|
65 |
+
try:
|
66 |
+
OUTPUT_DIR = Path(__file__).parent.parent / "output" / "sim_maps"
|
67 |
+
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
68 |
+
fig.savefig(
|
69 |
+
OUTPUT_DIR / filename,
|
70 |
+
bbox_inches="tight",
|
71 |
+
pad_inches=0,
|
72 |
+
)
|
73 |
+
except Exception as e:
|
74 |
+
print(f"Failed to save figure: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
|
77 |
def annotate_plot(ax, query, selected_token):
|
78 |
+
# Add the query text as a title over the image with opacity
|
79 |
+
ax.text(
|
80 |
+
0.5,
|
81 |
+
0.95, # Adjust the position to be on the image (y=0.1 is 10% from the bottom)
|
82 |
+
query,
|
83 |
+
fontsize=18,
|
84 |
+
color="white",
|
85 |
ha="center",
|
86 |
va="center",
|
87 |
+
alpha=0.8, # Set opacity (1 is fully opaque, 0 is fully transparent)
|
88 |
+
bbox=dict(
|
89 |
+
boxstyle="round,pad=0.5", fc="black", ec="none", lw=0, alpha=0.5
|
90 |
+
), # Add a semi-transparent background
|
91 |
+
transform=ax.transAxes, # Ensure the coordinates are relative to the axes
|
92 |
+
)
|
93 |
+
|
94 |
+
# Add annotation with the selected token over the image with opacity
|
95 |
+
ax.text(
|
96 |
+
0.5,
|
97 |
+
0.05, # Position towards the top of the image
|
98 |
+
f"Selected token: `{selected_token}`",
|
99 |
fontsize=18,
|
100 |
+
color="white",
|
101 |
+
ha="center",
|
102 |
+
va="center",
|
103 |
+
alpha=0.8, # Set opacity for the text
|
104 |
+
bbox=dict(
|
105 |
+
boxstyle="round,pad=0.3", fc="black", ec="none", lw=0, alpha=0.5
|
106 |
+
), # Semi-transparent background
|
107 |
+
transform=ax.transAxes, # Keep the coordinates relative to the axes
|
108 |
)
|
109 |
return ax
|
110 |
|
111 |
|
112 |
+
def gen_similarity_maps(
|
|
|
113 |
model: ColPali,
|
114 |
+
processor: ColPaliProcessor,
|
115 |
device,
|
116 |
vit_config,
|
117 |
query: str,
|
118 |
query_embs: torch.Tensor,
|
119 |
token_idx_map: dict,
|
120 |
+
images: List[Union[Path, str]],
|
121 |
+
vespa_sim_maps: List[str],
|
122 |
+
) -> List[Dict[str, str]]:
|
123 |
+
"""
|
124 |
+
Generate similarity maps for the given images and query, and return base64-encoded blended images.
|
125 |
+
|
126 |
+
Args:
|
127 |
+
model (ColPali): The model used for generating embeddings.
|
128 |
+
processor (ColPaliProcessor): Processor for images and text.
|
129 |
+
device: Device to run the computations on.
|
130 |
+
vit_config: Configuration for the Vision Transformer.
|
131 |
+
query (str): The query string.
|
132 |
+
query_embs (torch.Tensor): Query embeddings.
|
133 |
+
token_idx_map (dict): Mapping from tokens to their indices.
|
134 |
+
images (List[Union[Path, str]]): List of image paths or base64-encoded strings.
|
135 |
+
|
136 |
+
Returns:
|
137 |
+
List[Dict[str, str]]: A list where each item is a dictionary mapping tokens to base64-encoded blended images.
|
138 |
+
"""
|
139 |
+
|
140 |
+
start = time.perf_counter()
|
141 |
+
|
142 |
+
# Prepare the colormap once to avoid recomputation
|
143 |
+
colormap = cm.get_cmap("viridis")
|
144 |
+
|
145 |
+
# Process images and store original images and sizes
|
146 |
+
processed_images = []
|
147 |
+
original_images = []
|
148 |
+
original_sizes = []
|
149 |
+
for img in images:
|
150 |
+
if isinstance(img, Path):
|
151 |
+
try:
|
152 |
+
img_pil = Image.open(img).convert("RGB")
|
153 |
+
except Exception as e:
|
154 |
+
raise ValueError(f"Failed to open image from path: {e}")
|
155 |
+
elif isinstance(img, str):
|
156 |
+
try:
|
157 |
+
img_pil = Image.open(BytesIO(base64.b64decode(img))).convert("RGB")
|
158 |
+
except Exception as e:
|
159 |
+
raise ValueError(f"Failed to open image from base64 string: {e}")
|
160 |
+
else:
|
161 |
+
raise ValueError(f"Unsupported image type: {type(img)}")
|
162 |
+
original_images.append(img_pil.copy())
|
163 |
+
original_sizes.append(img_pil.size) # (width, height)
|
164 |
+
processed_images.append(img_pil)
|
165 |
+
|
166 |
+
# If similarity maps are provided, use them instead of computing them
|
167 |
+
if vespa_sim_maps:
|
168 |
+
print("Using provided similarity maps")
|
169 |
+
# A sim map looks like this:
|
170 |
+
# "similarities": [
|
171 |
+
# {
|
172 |
+
# "address": {
|
173 |
+
# "patch": "0",
|
174 |
+
# "querytoken": "0"
|
175 |
+
# },
|
176 |
+
# "value": 1.2599412202835083
|
177 |
+
# },
|
178 |
+
# ... and so on.
|
179 |
+
# Now turn these into a tensor of same shape as previous similarity map
|
180 |
+
vespa_sim_map_tensor = torch.zeros(
|
181 |
+
(len(vespa_sim_maps), query_embs.size(dim=1), vit_config.n_patch_per_dim, vit_config.n_patch_per_dim)
|
182 |
+
)
|
183 |
+
for idx, vespa_sim_map in enumerate(vespa_sim_maps):
|
184 |
+
for cell in vespa_sim_map["similarities"]["cells"]:
|
185 |
+
patch = int(cell["address"]["patch"])
|
186 |
+
if patch >= processor.image_seq_length:
|
187 |
+
continue
|
188 |
+
query_token = int(cell["address"]["querytoken"])
|
189 |
+
value = cell["value"]
|
190 |
+
vespa_sim_map_tensor[idx, int(query_token), int(patch) // vit_config.n_patch_per_dim, int(patch) % vit_config.n_patch_per_dim] = value
|
191 |
+
|
192 |
+
# Normalize the similarity map per query token
|
193 |
+
similarity_map_normalized = normalize_similarity_map_per_query_token(vespa_sim_map_tensor)
|
194 |
+
else:
|
195 |
+
# Preprocess inputs
|
196 |
+
print("Computing similarity maps")
|
197 |
+
start2 = time.perf_counter()
|
198 |
+
input_image_processed = processor.process_images(processed_images).to(device)
|
199 |
+
|
200 |
+
# Forward passes
|
201 |
+
with torch.no_grad():
|
202 |
+
output_image = model.forward(**input_image_processed)
|
203 |
+
|
204 |
+
# Remove the special tokens from the output
|
205 |
+
output_image = output_image[:, : processor.image_seq_length, :]
|
206 |
+
|
207 |
+
# Rearrange the output image tensor to represent the 2D grid of patches
|
208 |
+
output_image = rearrange(
|
209 |
+
output_image,
|
210 |
+
"b (h w) c -> b h w c",
|
211 |
+
h=vit_config.n_patch_per_dim,
|
212 |
+
w=vit_config.n_patch_per_dim,
|
213 |
+
)
|
214 |
+
|
215 |
+
# Ensure query_embs has batch dimension
|
216 |
+
if query_embs.dim() == 2:
|
217 |
+
query_embs = query_embs.unsqueeze(0).to(device)
|
218 |
+
else:
|
219 |
+
query_embs = query_embs.to(device)
|
220 |
+
|
221 |
+
# Compute the similarity map
|
222 |
+
similarity_map = torch.einsum(
|
223 |
+
"bnk,bhwk->bnhw", query_embs, output_image
|
224 |
+
) # Shape: (batch_size, query_tokens, h, w)
|
225 |
+
|
226 |
+
end2 = time.perf_counter()
|
227 |
+
print(f"Similarity map computation took: {end2 - start2} s")
|
228 |
+
|
229 |
+
# Normalize the similarity map per query token
|
230 |
+
similarity_map_normalized = normalize_similarity_map_per_query_token(similarity_map)
|
231 |
+
|
232 |
+
# Collect the blended images
|
233 |
+
start3 = time.perf_counter()
|
234 |
+
results = []
|
235 |
+
for idx, img in enumerate(original_images):
|
236 |
+
original_size = original_sizes[idx] # (width, height)
|
237 |
+
result_per_image = {}
|
238 |
+
for token, token_idx in token_idx_map.items():
|
239 |
+
if is_special_token(token):
|
240 |
+
continue
|
241 |
+
|
242 |
+
# Get the similarity map for this image and the selected token
|
243 |
+
sim_map = similarity_map_normalized[idx, token_idx, :, :] # Shape: (h, w)
|
244 |
+
|
245 |
+
# Move the similarity map to CPU and convert to NumPy array
|
246 |
+
sim_map_np = sim_map.cpu().numpy()
|
247 |
+
|
248 |
+
# Resize the similarity map to the original image size
|
249 |
+
sim_map_img = Image.fromarray(sim_map_np)
|
250 |
+
sim_map_resized = sim_map_img.resize(original_size, resample=Image.BICUBIC)
|
251 |
+
|
252 |
+
# Convert the resized similarity map to a NumPy array
|
253 |
+
sim_map_resized_np = np.array(sim_map_resized, dtype=np.float32)
|
254 |
+
|
255 |
+
# Normalize the similarity map to range [0, 1]
|
256 |
+
sim_map_min = sim_map_resized_np.min()
|
257 |
+
sim_map_max = sim_map_resized_np.max()
|
258 |
+
if sim_map_max - sim_map_min > 1e-6:
|
259 |
+
sim_map_normalized = (sim_map_resized_np - sim_map_min) / (
|
260 |
+
sim_map_max - sim_map_min
|
261 |
+
)
|
262 |
+
else:
|
263 |
+
sim_map_normalized = np.zeros_like(sim_map_resized_np)
|
264 |
+
|
265 |
+
# Apply a colormap to the normalized similarity map
|
266 |
+
heatmap = colormap(sim_map_normalized) # Returns an RGBA array
|
267 |
+
|
268 |
+
# Convert the heatmap to a PIL Image
|
269 |
+
heatmap_uint8 = (heatmap * 255).astype(np.uint8)
|
270 |
+
heatmap_img = Image.fromarray(heatmap_uint8)
|
271 |
+
|
272 |
+
# Ensure both images are in RGBA mode
|
273 |
+
original_img_rgba = img.convert("RGBA")
|
274 |
+
heatmap_img_rgba = heatmap_img.convert("RGBA")
|
275 |
+
|
276 |
+
# Overlay the heatmap onto the original image
|
277 |
+
blended_img = Image.blend(
|
278 |
+
original_img_rgba, heatmap_img_rgba, alpha=0.4
|
279 |
+
) # Adjust alpha as needed
|
280 |
+
# Save the blended image to a BytesIO buffer
|
281 |
+
buffer = io.BytesIO()
|
282 |
+
blended_img.save(buffer, format="PNG")
|
283 |
+
buffer.seek(0)
|
284 |
+
|
285 |
+
# Encode the image to base64
|
286 |
+
blended_img_base64 = base64.b64encode(buffer.read()).decode("utf-8")
|
287 |
+
|
288 |
+
# Store the base64-encoded image
|
289 |
+
result_per_image[token] = blended_img_base64
|
290 |
+
results.append(result_per_image)
|
291 |
+
end3 = time.perf_counter()
|
292 |
+
print(f"Collecting blended images took: {end3 - start3} s")
|
293 |
+
print(f"Total heatmap generation took: {end3 - start} s")
|
294 |
+
return results
|
295 |
|
296 |
|
297 |
def get_query_embeddings_and_token_map(
|
298 |
+
processor, model, query
|
299 |
) -> Tuple[torch.Tensor, dict]:
|
300 |
inputs = processor.process_queries([query]).to(model.device)
|
301 |
with torch.no_grad():
|
|
|
328 |
) -> dict:
|
329 |
async with app.asyncio(connections=1, total_timeout=120) as session:
|
330 |
query_embedding = format_q_embs(q_emb)
|
331 |
+
|
332 |
+
start = time.perf_counter()
|
333 |
response: VespaQueryResponse = await session.query(
|
334 |
body={
|
335 |
+
"yql": "select id,title,url,full_image,page_number,snippet,text,summaryfeatures from pdf_page where userQuery();",
|
336 |
"ranking": "default",
|
337 |
"query": query,
|
338 |
"timeout": timeout,
|
|
|
343 |
},
|
344 |
)
|
345 |
assert response.is_successful(), response.json
|
346 |
+
stop = time.perf_counter()
|
347 |
+
print(f"Query time + data transfer took: {stop - start} s, vespa said searchtime was {response.json.get('timing', {}).get('searchtime', -1)} s")
|
348 |
+
open("response.json", "w").write(json.dumps(response.json))
|
349 |
+
return format_query_results(query, response)
|
350 |
+
|
351 |
+
|
352 |
+
async def query_vespa_bm25(
|
353 |
+
app: Vespa,
|
354 |
+
query: str,
|
355 |
+
hits: int = 3,
|
356 |
+
timeout: str = "10s",
|
357 |
+
**kwargs,
|
358 |
+
) -> dict:
|
359 |
+
async with app.asyncio(connections=1, total_timeout=120) as session:
|
360 |
+
response: VespaQueryResponse = await session.query(
|
361 |
+
body={
|
362 |
+
"yql": "select id,title,url,full_image,page_number,snippet,text from pdf_page where userQuery();",
|
363 |
+
"ranking": "bm25",
|
364 |
+
"query": query,
|
365 |
+
"timeout": timeout,
|
366 |
+
"hits": hits,
|
367 |
+
"presentation.timing": True,
|
368 |
+
**kwargs,
|
369 |
+
},
|
370 |
+
)
|
371 |
+
assert response.is_successful(), response.json
|
372 |
return format_query_results(query, response)
|
373 |
|
374 |
|
|
|
433 |
body={
|
434 |
**query_tensors,
|
435 |
"presentation.timing": True,
|
436 |
+
# if we use rank({nn_string}, userQuery()), dynamic summary doesn't work, see https://github.com/vespa-engine/vespa/issues/28704
|
437 |
+
"yql": f"select id,title,snippet,text,url,full_image,page_number from pdf_page where {nn_string} or userQuery()",
|
438 |
"ranking.profile": "retrieval-and-rerank",
|
439 |
"timeout": timeout,
|
440 |
"hits": hits,
|
441 |
+
"query": query,
|
442 |
**kwargs,
|
443 |
},
|
444 |
)
|
|
|
447 |
|
448 |
|
449 |
def is_special_token(token: str) -> bool:
|
450 |
+
# Pattern for tokens that start with '<', numbers, whitespace, or single characters, or the string 'Question'
|
451 |
+
pattern = re.compile(r"^<.*$|^\d+$|^\s+$|^\w$|^Question$")
|
452 |
if pattern.match(token):
|
453 |
return True
|
454 |
return False
|
|
|
459 |
processor: ColPaliProcessor,
|
460 |
model: ColPali,
|
461 |
query: str,
|
462 |
+
q_embs: torch.Tensor,
|
463 |
+
token_to_idx: Dict[str, int],
|
464 |
+
ranking: str,
|
465 |
+
) -> Dict[str, Any]:
|
466 |
# Get the query embeddings and token map
|
467 |
print(query)
|
|
|
|
|
|
|
|
|
|
|
|
|
468 |
|
469 |
+
print(token_to_idx)
|
470 |
+
if ranking == "nn+colpali":
|
471 |
result = await query_vespa_nearest_neighbor(app, query, q_embs)
|
472 |
+
elif ranking == "bm25+colpali":
|
473 |
result = await query_vespa_default(app, query, q_embs)
|
474 |
+
elif ranking == "bm25":
|
475 |
+
result = await query_vespa_bm25(app, query)
|
476 |
+
else:
|
477 |
+
raise ValueError(f"Unsupported ranking: {ranking}")
|
478 |
+
# Print score, title id, and text of the results
|
479 |
for idx, child in enumerate(result["root"]["children"]):
|
480 |
print(
|
481 |
f"Result {idx+1}: {child['relevance']}, {child['fields']['title']}, {child['fields']['id']}"
|
482 |
)
|
483 |
+
for single_result in result["root"]["children"]:
|
484 |
+
print(single_result["fields"].keys())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
485 |
return result
|
486 |
|
487 |
|
488 |
+
def add_sim_maps_to_result(
|
489 |
+
result: Dict[str, Any],
|
490 |
+
model: ColPali,
|
491 |
+
processor: ColPaliProcessor,
|
492 |
+
query: str,
|
493 |
+
q_embs: Any,
|
494 |
+
token_to_idx: Dict[str, int],
|
495 |
+
) -> Dict[str, Any]:
|
496 |
+
vit_config = load_vit_config(model)
|
497 |
+
imgs: List[str] = []
|
498 |
+
vespa_sim_maps: List[str] = []
|
499 |
+
for single_result in result["root"]["children"]:
|
500 |
+
img = single_result["fields"]["full_image"]
|
501 |
+
if img:
|
502 |
+
imgs.append(img)
|
503 |
+
vespa_sim_map = single_result["fields"].get("summaryfeatures", None)
|
504 |
+
if vespa_sim_map:
|
505 |
+
vespa_sim_maps.append(vespa_sim_map)
|
506 |
+
sim_map_imgs = gen_similarity_maps(
|
507 |
+
model=model,
|
508 |
+
processor=processor,
|
509 |
+
device=model.device,
|
510 |
+
vit_config=vit_config,
|
511 |
+
query=query,
|
512 |
+
query_embs=q_embs,
|
513 |
+
token_idx_map=token_to_idx,
|
514 |
+
images=imgs,
|
515 |
+
vespa_sim_maps=vespa_sim_maps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
516 |
)
|
517 |
+
for single_result, sim_map_dict in zip(result["root"]["children"], sim_map_imgs):
|
518 |
+
for token, sim_mapb64 in sim_map_dict.items():
|
519 |
+
single_result["fields"][f"sim_map_{token}"] = sim_mapb64
|
520 |
return result
|
521 |
|
522 |
|
|
|
530 |
/ "assets"
|
531 |
/ "ConocoPhillips Sustainability Highlights - Nature (24-0976).png"
|
532 |
)
|
533 |
+
q_embs, token_to_idx = get_query_embeddings_and_token_map(
|
534 |
+
processor,
|
535 |
+
model,
|
536 |
+
query,
|
537 |
+
)
|
538 |
+
figs_images = gen_similarity_maps(
|
539 |
+
model,
|
540 |
+
processor,
|
541 |
+
model.device,
|
542 |
+
vit_config,
|
543 |
+
query=query,
|
544 |
+
query_embs=q_embs,
|
545 |
+
token_idx_map=token_to_idx,
|
546 |
+
images=[image_filepath],
|
547 |
+
vespa_sim_maps=None,
|
548 |
)
|
549 |
+
for fig_token in figs_images:
|
550 |
+
for token, (fig, ax) in fig_token.items():
|
551 |
+
print(f"Token: {token}")
|
552 |
+
save_figure(fig, f"similarity_map_{token}.png")
|
553 |
print("Done")
|
backend/modelmanager.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .colpali import load_model
|
2 |
+
|
3 |
+
|
4 |
+
class ModelManager:
|
5 |
+
_instance = None
|
6 |
+
model = None
|
7 |
+
processor = None
|
8 |
+
|
9 |
+
@staticmethod
|
10 |
+
def get_instance():
|
11 |
+
if ModelManager._instance is None:
|
12 |
+
ModelManager._instance = ModelManager()
|
13 |
+
ModelManager._instance.initialize_model_and_processor()
|
14 |
+
return ModelManager._instance
|
15 |
+
|
16 |
+
def initialize_model_and_processor(self):
|
17 |
+
if self.model is None or self.processor is None: # Ensure no reinitialization
|
18 |
+
self.model, self.processor = load_model()
|
19 |
+
if self.model is None or self.processor is None:
|
20 |
+
print("Failed to initialize model or processor at startup")
|
21 |
+
else:
|
22 |
+
print("Model and processor loaded at startup")
|
colpali-with-snippets/schemas/pdf_page.sd
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
schema pdf_page {
|
2 |
+
document pdf_page {
|
3 |
+
field id type string {
|
4 |
+
indexing: summary | index
|
5 |
+
match {
|
6 |
+
word
|
7 |
+
}
|
8 |
+
}
|
9 |
+
field url type string {
|
10 |
+
indexing: summary | index
|
11 |
+
}
|
12 |
+
field title type string {
|
13 |
+
indexing: summary | index
|
14 |
+
index: enable-bm25
|
15 |
+
match {
|
16 |
+
text
|
17 |
+
}
|
18 |
+
}
|
19 |
+
field page_number type int {
|
20 |
+
indexing: summary | attribute
|
21 |
+
}
|
22 |
+
field image type raw {
|
23 |
+
indexing: summary
|
24 |
+
}
|
25 |
+
field full_image type raw {
|
26 |
+
indexing: summary
|
27 |
+
}
|
28 |
+
field text type string {
|
29 |
+
indexing: summary | index
|
30 |
+
index: enable-bm25
|
31 |
+
match {
|
32 |
+
text
|
33 |
+
}
|
34 |
+
}
|
35 |
+
field embedding type tensor<int8>(patch{}, v[16]) {
|
36 |
+
indexing: attribute | index
|
37 |
+
attribute {
|
38 |
+
distance-metric: hamming
|
39 |
+
}
|
40 |
+
index {
|
41 |
+
hnsw {
|
42 |
+
max-links-per-node: 32
|
43 |
+
neighbors-to-explore-at-insert: 400
|
44 |
+
}
|
45 |
+
}
|
46 |
+
}
|
47 |
+
}
|
48 |
+
fieldset default {
|
49 |
+
fields: title, url, page_number, text
|
50 |
+
}
|
51 |
+
|
52 |
+
document-summary default {
|
53 |
+
from-disk
|
54 |
+
|
55 |
+
summary text {
|
56 |
+
bolding: on
|
57 |
+
}
|
58 |
+
|
59 |
+
summary snippet {
|
60 |
+
source: text
|
61 |
+
dynamic
|
62 |
+
}
|
63 |
+
}
|
64 |
+
|
65 |
+
fieldset image {
|
66 |
+
fields: image
|
67 |
+
}
|
68 |
+
|
69 |
+
rank-profile bm25 {
|
70 |
+
first-phase {
|
71 |
+
expression: bm25(title) + bm25(text)
|
72 |
+
}
|
73 |
+
}
|
74 |
+
|
75 |
+
rank-profile default {
|
76 |
+
inputs {
|
77 |
+
query(qt) tensor<float>(querytoken{}, v[128])
|
78 |
+
|
79 |
+
}
|
80 |
+
function max_sim() {
|
81 |
+
expression {
|
82 |
+
|
83 |
+
sum(
|
84 |
+
reduce(
|
85 |
+
sum(
|
86 |
+
query(qt) * unpack_bits(attribute(embedding)) , v
|
87 |
+
),
|
88 |
+
max, patch
|
89 |
+
),
|
90 |
+
querytoken
|
91 |
+
)
|
92 |
+
|
93 |
+
}
|
94 |
+
}
|
95 |
+
function similarities() {
|
96 |
+
expression {
|
97 |
+
sum(
|
98 |
+
query(qt) * unpack_bits(attribute(embedding)), v
|
99 |
+
)
|
100 |
+
}
|
101 |
+
}
|
102 |
+
function bm25_score() {
|
103 |
+
expression {
|
104 |
+
bm25(title) + bm25(text)
|
105 |
+
}
|
106 |
+
}
|
107 |
+
first-phase {
|
108 |
+
expression {
|
109 |
+
bm25_score
|
110 |
+
}
|
111 |
+
}
|
112 |
+
second-phase {
|
113 |
+
rerank-count: 10
|
114 |
+
expression {
|
115 |
+
max_sim
|
116 |
+
}
|
117 |
+
}
|
118 |
+
summary-features: similarities
|
119 |
+
}
|
120 |
+
rank-profile retrieval-and-rerank {
|
121 |
+
inputs {
|
122 |
+
query(rq0) tensor<int8>(v[16])
|
123 |
+
query(rq1) tensor<int8>(v[16])
|
124 |
+
query(rq2) tensor<int8>(v[16])
|
125 |
+
query(rq3) tensor<int8>(v[16])
|
126 |
+
query(rq4) tensor<int8>(v[16])
|
127 |
+
query(rq5) tensor<int8>(v[16])
|
128 |
+
query(rq6) tensor<int8>(v[16])
|
129 |
+
query(rq7) tensor<int8>(v[16])
|
130 |
+
query(rq8) tensor<int8>(v[16])
|
131 |
+
query(rq9) tensor<int8>(v[16])
|
132 |
+
query(rq10) tensor<int8>(v[16])
|
133 |
+
query(rq11) tensor<int8>(v[16])
|
134 |
+
query(rq12) tensor<int8>(v[16])
|
135 |
+
query(rq13) tensor<int8>(v[16])
|
136 |
+
query(rq14) tensor<int8>(v[16])
|
137 |
+
query(rq15) tensor<int8>(v[16])
|
138 |
+
query(rq16) tensor<int8>(v[16])
|
139 |
+
query(rq17) tensor<int8>(v[16])
|
140 |
+
query(rq18) tensor<int8>(v[16])
|
141 |
+
query(rq19) tensor<int8>(v[16])
|
142 |
+
query(rq20) tensor<int8>(v[16])
|
143 |
+
query(rq21) tensor<int8>(v[16])
|
144 |
+
query(rq22) tensor<int8>(v[16])
|
145 |
+
query(rq23) tensor<int8>(v[16])
|
146 |
+
query(rq24) tensor<int8>(v[16])
|
147 |
+
query(rq25) tensor<int8>(v[16])
|
148 |
+
query(rq26) tensor<int8>(v[16])
|
149 |
+
query(rq27) tensor<int8>(v[16])
|
150 |
+
query(rq28) tensor<int8>(v[16])
|
151 |
+
query(rq29) tensor<int8>(v[16])
|
152 |
+
query(rq30) tensor<int8>(v[16])
|
153 |
+
query(rq31) tensor<int8>(v[16])
|
154 |
+
query(rq32) tensor<int8>(v[16])
|
155 |
+
query(rq33) tensor<int8>(v[16])
|
156 |
+
query(rq34) tensor<int8>(v[16])
|
157 |
+
query(rq35) tensor<int8>(v[16])
|
158 |
+
query(rq36) tensor<int8>(v[16])
|
159 |
+
query(rq37) tensor<int8>(v[16])
|
160 |
+
query(rq38) tensor<int8>(v[16])
|
161 |
+
query(rq39) tensor<int8>(v[16])
|
162 |
+
query(rq40) tensor<int8>(v[16])
|
163 |
+
query(rq41) tensor<int8>(v[16])
|
164 |
+
query(rq42) tensor<int8>(v[16])
|
165 |
+
query(rq43) tensor<int8>(v[16])
|
166 |
+
query(rq44) tensor<int8>(v[16])
|
167 |
+
query(rq45) tensor<int8>(v[16])
|
168 |
+
query(rq46) tensor<int8>(v[16])
|
169 |
+
query(rq47) tensor<int8>(v[16])
|
170 |
+
query(rq48) tensor<int8>(v[16])
|
171 |
+
query(rq49) tensor<int8>(v[16])
|
172 |
+
query(rq50) tensor<int8>(v[16])
|
173 |
+
query(rq51) tensor<int8>(v[16])
|
174 |
+
query(rq52) tensor<int8>(v[16])
|
175 |
+
query(rq53) tensor<int8>(v[16])
|
176 |
+
query(rq54) tensor<int8>(v[16])
|
177 |
+
query(rq55) tensor<int8>(v[16])
|
178 |
+
query(rq56) tensor<int8>(v[16])
|
179 |
+
query(rq57) tensor<int8>(v[16])
|
180 |
+
query(rq58) tensor<int8>(v[16])
|
181 |
+
query(rq59) tensor<int8>(v[16])
|
182 |
+
query(rq60) tensor<int8>(v[16])
|
183 |
+
query(rq61) tensor<int8>(v[16])
|
184 |
+
query(rq62) tensor<int8>(v[16])
|
185 |
+
query(rq63) tensor<int8>(v[16])
|
186 |
+
query(qt) tensor<float>(querytoken{}, v[128])
|
187 |
+
query(qtb) tensor<int8>(querytoken{}, v[16])
|
188 |
+
|
189 |
+
}
|
190 |
+
function max_sim() {
|
191 |
+
expression {
|
192 |
+
|
193 |
+
sum(
|
194 |
+
reduce(
|
195 |
+
sum(
|
196 |
+
query(qt) * unpack_bits(attribute(embedding)) , v
|
197 |
+
),
|
198 |
+
max, patch
|
199 |
+
),
|
200 |
+
querytoken
|
201 |
+
)
|
202 |
+
|
203 |
+
}
|
204 |
+
}
|
205 |
+
function max_sim_binary() {
|
206 |
+
expression {
|
207 |
+
|
208 |
+
sum(
|
209 |
+
reduce(
|
210 |
+
1/(1 + sum(
|
211 |
+
hamming(query(qtb), attribute(embedding)) ,v)
|
212 |
+
),
|
213 |
+
max,
|
214 |
+
patch
|
215 |
+
),
|
216 |
+
querytoken
|
217 |
+
)
|
218 |
+
|
219 |
+
}
|
220 |
+
}
|
221 |
+
first-phase {
|
222 |
+
expression {
|
223 |
+
max_sim_binary
|
224 |
+
}
|
225 |
+
}
|
226 |
+
second-phase {
|
227 |
+
rerank-count: 10
|
228 |
+
expression {
|
229 |
+
max_sim
|
230 |
+
}
|
231 |
+
}
|
232 |
+
}
|
233 |
+
}
|
colpali-with-snippets/search/query-profiles/default.xml
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
<query-profile id="default" type="root">
|
2 |
+
</query-profile>
|
colpali-with-snippets/search/query-profiles/types/root.xml
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
<query-profile-type id="root">
|
2 |
+
</query-profile-type>
|
colpali-with-snippets/services.xml
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<services version="1.0">
|
3 |
+
<container id="colpalidemo_container" version="1.0">
|
4 |
+
<search></search>
|
5 |
+
<document-api></document-api>
|
6 |
+
<document-processing></document-processing>
|
7 |
+
<clients>
|
8 |
+
<client id="mtls" permissions="read,write">
|
9 |
+
<certificate file="security/clients.pem" />
|
10 |
+
</client>
|
11 |
+
<client id="token_write" permissions="read,write">
|
12 |
+
<token id="colpalidemo_write" />
|
13 |
+
</client>
|
14 |
+
<client id="token_read" permissions="read">
|
15 |
+
<token id="colpalidemo_read" />
|
16 |
+
</client>
|
17 |
+
</clients>
|
18 |
+
<config name="container.qr-searchers">
|
19 |
+
<tag>
|
20 |
+
<bold>
|
21 |
+
<open><strong></open>
|
22 |
+
<close></strong></close>
|
23 |
+
</bold>
|
24 |
+
<separator>...</separator>
|
25 |
+
</tag>
|
26 |
+
</config>
|
27 |
+
</container>
|
28 |
+
<content id="colpalidemo_content" version="1.0">
|
29 |
+
<redundancy>1</redundancy>
|
30 |
+
<documents>
|
31 |
+
<document type="pdf_page" mode="index"></document>
|
32 |
+
</documents>
|
33 |
+
<nodes>
|
34 |
+
<node distribution-key="0" hostalias="node1"></node>
|
35 |
+
</nodes>
|
36 |
+
<config name="vespa.config.search.summary.juniperrc">
|
37 |
+
<max_matches>2</max_matches>
|
38 |
+
<length>1000</length>
|
39 |
+
<surround_max>500</surround_max>
|
40 |
+
<min_length>300</min_length>
|
41 |
+
</config>
|
42 |
+
</content>
|
43 |
+
</services>
|
colpalidemo/schemas/pdf_page.sd
CHANGED
@@ -22,6 +22,9 @@ schema pdf_page {
|
|
22 |
field image type raw {
|
23 |
indexing: summary
|
24 |
}
|
|
|
|
|
|
|
25 |
field text type string {
|
26 |
indexing: summary | index
|
27 |
index: enable-bm25
|
@@ -43,7 +46,10 @@ schema pdf_page {
|
|
43 |
}
|
44 |
}
|
45 |
fieldset default {
|
46 |
-
fields: title, text
|
|
|
|
|
|
|
47 |
}
|
48 |
rank-profile default {
|
49 |
inputs {
|
|
|
22 |
field image type raw {
|
23 |
indexing: summary
|
24 |
}
|
25 |
+
field full_image type raw {
|
26 |
+
indexing: summary
|
27 |
+
}
|
28 |
field text type string {
|
29 |
indexing: summary | index
|
30 |
index: enable-bm25
|
|
|
46 |
}
|
47 |
}
|
48 |
fieldset default {
|
49 |
+
fields: title, url, page_number, text
|
50 |
+
}
|
51 |
+
fieldset image {
|
52 |
+
fields: image
|
53 |
}
|
54 |
rank-profile default {
|
55 |
inputs {
|
deploy_vespa_app.py
CHANGED
@@ -16,6 +16,7 @@ from vespa.package import (
|
|
16 |
)
|
17 |
from vespa.deployment import VespaCloud
|
18 |
import os
|
|
|
19 |
|
20 |
|
21 |
def main():
|
@@ -60,6 +61,7 @@ def main():
|
|
60 |
name="page_number", type="int", indexing=["summary", "attribute"]
|
61 |
),
|
62 |
Field(name="image", type="raw", indexing=["summary"]),
|
|
|
63 |
Field(
|
64 |
name="text",
|
65 |
type="string",
|
@@ -190,10 +192,12 @@ def main():
|
|
190 |
tenant=tenant_name,
|
191 |
application=vespa_app_name,
|
192 |
key_content=vespa_team_api_key,
|
193 |
-
|
|
|
194 |
)
|
195 |
|
196 |
-
app = vespa_cloud.deploy()
|
|
|
197 |
|
198 |
# Output the endpoint URL
|
199 |
endpoint_url = vespa_cloud.get_token_endpoint()
|
|
|
16 |
)
|
17 |
from vespa.deployment import VespaCloud
|
18 |
import os
|
19 |
+
from pathlib import Path
|
20 |
|
21 |
|
22 |
def main():
|
|
|
61 |
name="page_number", type="int", indexing=["summary", "attribute"]
|
62 |
),
|
63 |
Field(name="image", type="raw", indexing=["summary"]),
|
64 |
+
Field(name="full_image", type="raw", indexing=["summary"]),
|
65 |
Field(
|
66 |
name="text",
|
67 |
type="string",
|
|
|
192 |
tenant=tenant_name,
|
193 |
application=vespa_app_name,
|
194 |
key_content=vespa_team_api_key,
|
195 |
+
application_root="colpali-with-snippets",
|
196 |
+
#application_package=vespa_application_package,
|
197 |
)
|
198 |
|
199 |
+
#app = vespa_cloud.deploy()
|
200 |
+
vespa_cloud.deploy_from_disk("default", "colpali-with-snippets")
|
201 |
|
202 |
# Output the endpoint URL
|
203 |
endpoint_url = vespa_cloud.get_token_endpoint()
|
feed_vespa.py
CHANGED
@@ -159,6 +159,7 @@ def main():
|
|
159 |
base_64_image = get_base64_image(
|
160 |
scale_image(image, 640), add_url_prefix=False
|
161 |
)
|
|
|
162 |
embedding_dict = dict()
|
163 |
for idx, patch_embedding in enumerate(embedding):
|
164 |
binary_vector = (
|
@@ -178,6 +179,7 @@ def main():
|
|
178 |
"title": title,
|
179 |
"page_number": page_number,
|
180 |
"image": base_64_image,
|
|
|
181 |
"text": page_text,
|
182 |
"embedding": embedding_dict,
|
183 |
},
|
|
|
159 |
base_64_image = get_base64_image(
|
160 |
scale_image(image, 640), add_url_prefix=False
|
161 |
)
|
162 |
+
base_64_full_image = get_base64_image(image, add_url_prefix=False)
|
163 |
embedding_dict = dict()
|
164 |
for idx, patch_embedding in enumerate(embedding):
|
165 |
binary_vector = (
|
|
|
179 |
"title": title,
|
180 |
"page_number": page_number,
|
181 |
"image": base_64_image,
|
182 |
+
"full_image": base_64_full_image,
|
183 |
"text": page_text,
|
184 |
"embedding": embedding_dict,
|
185 |
},
|
frontend/app.py
CHANGED
@@ -1,26 +1,59 @@
|
|
1 |
from urllib.parse import quote_plus
|
|
|
2 |
|
3 |
-
from fasthtml.components import
|
4 |
-
from fasthtml.xtend import
|
5 |
from lucide_fasthtml import Lucide
|
6 |
-
from shad4fast import Button, Input,
|
7 |
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
"""
|
12 |
window.onload = function() {
|
13 |
const input = document.getElementById('search-input');
|
14 |
const button = document.querySelector('[data-button="search-button"]');
|
15 |
-
|
16 |
-
|
17 |
-
checkInputValue()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
};
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
-
|
|
|
24 |
grid_cls = "grid gap-2 items-center p-3 bg-muted/80 dark:bg-muted/40 w-full"
|
25 |
|
26 |
if with_border:
|
@@ -41,7 +74,30 @@ def SearchBox(with_border=False, query_value=""):
|
|
41 |
cls="relative",
|
42 |
),
|
43 |
Div(
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
Button(
|
46 |
Lucide(icon="arrow-right", size="21"),
|
47 |
size="sm",
|
@@ -51,23 +107,23 @@ def SearchBox(with_border=False, query_value=""):
|
|
51 |
),
|
52 |
cls="flex justify-between",
|
53 |
),
|
54 |
-
check_input_script
|
55 |
-
action=f"/search?query={quote_plus(query_value)}",
|
56 |
method="GET",
|
57 |
-
hx_get=f"/fetch_results?query={quote_plus(query_value)}",
|
58 |
-
hx_trigger="load",
|
59 |
-
hx_target="#search-results",
|
60 |
-
hx_swap="outerHTML",
|
61 |
-
hx_indicator="#loading-indicator",
|
62 |
cls=grid_cls,
|
63 |
)
|
64 |
|
65 |
|
66 |
def SampleQueries():
|
67 |
sample_queries = [
|
68 |
-
"
|
69 |
-
"
|
70 |
-
"How
|
71 |
]
|
72 |
|
73 |
query_badges = []
|
@@ -83,7 +139,7 @@ def SampleQueries():
|
|
83 |
cls="flex gap-2 items-center",
|
84 |
),
|
85 |
variant="outline",
|
86 |
-
cls="text-base font-normal text-muted-foreground",
|
87 |
),
|
88 |
href=f"/search?query={quote_plus(query)}",
|
89 |
cls="no-underline",
|
@@ -96,7 +152,7 @@ def SampleQueries():
|
|
96 |
def Hero():
|
97 |
return Div(
|
98 |
H1(
|
99 |
-
"Vespa.
|
100 |
cls="text-5xl md:text-7xl font-bold tracking-wide md:tracking-wider bg-clip-text text-transparent bg-gradient-to-r from-black to-gray-700 dark:from-white dark:to-gray-300 animate-fade-in",
|
101 |
),
|
102 |
P(
|
@@ -121,12 +177,14 @@ def Home():
|
|
121 |
|
122 |
def Search(request, search_results=[]):
|
123 |
query_value = request.query_params.get("query", "").strip()
|
|
|
|
|
|
|
|
|
124 |
|
125 |
return Div(
|
126 |
Div(
|
127 |
-
SearchBox(
|
128 |
-
query_value=query_value
|
129 |
-
), # Pass the query value to pre-fill the SearchBox
|
130 |
Div(
|
131 |
LoadingMessage(), # Show the loading message initially
|
132 |
id="search-results", # This will be replaced by the search results
|
@@ -145,7 +203,7 @@ def LoadingMessage():
|
|
145 |
)
|
146 |
|
147 |
|
148 |
-
def SearchResult(results
|
149 |
if not results:
|
150 |
return Div(
|
151 |
P(
|
@@ -159,35 +217,95 @@ def SearchResult(results=[]):
|
|
159 |
result_items = []
|
160 |
for result in results:
|
161 |
fields = result["fields"] # Extract the 'fields' part of each result
|
162 |
-
|
163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
)
|
165 |
-
# Print the fields that start with 'sim_map'
|
166 |
-
for key, value in fields.items():
|
167 |
-
if key.startswith("sim_map"):
|
168 |
-
print(f"{key}")
|
169 |
result_items.append(
|
170 |
Div(
|
171 |
Div(
|
172 |
-
|
173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
),
|
175 |
Div(
|
176 |
Div(
|
177 |
H2(fields["title"], cls="text-xl font-semibold"),
|
178 |
P(
|
179 |
-
fields["
|
180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
cls="text-sm grid gap-y-4",
|
182 |
),
|
183 |
-
cls="bg-background px-3 py-5",
|
184 |
),
|
185 |
-
cls="grid grid-cols-
|
186 |
)
|
187 |
)
|
188 |
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from urllib.parse import quote_plus
|
2 |
+
from typing import Optional
|
3 |
|
4 |
+
from fasthtml.components import H1, H2, Div, Form, Img, P, Span, NotStr
|
5 |
+
from fasthtml.xtend import A, Script
|
6 |
from lucide_fasthtml import Lucide
|
7 |
+
from shad4fast import Badge, Button, Input, Label, RadioGroup, RadioGroupItem
|
8 |
|
9 |
+
# JavaScript to check the input value and enable/disable the search button and radio buttons
|
10 |
+
check_input_script = Script(
|
11 |
+
"""
|
|
|
12 |
window.onload = function() {
|
13 |
const input = document.getElementById('search-input');
|
14 |
const button = document.querySelector('[data-button="search-button"]');
|
15 |
+
const radioGroupItems = document.querySelectorAll('button[data-ref="radio-item"]'); // Get all radio buttons
|
16 |
+
|
17 |
+
function checkInputValue() {
|
18 |
+
const isInputEmpty = input.value.trim() === "";
|
19 |
+
button.disabled = isInputEmpty; // Disable the submit button
|
20 |
+
radioGroupItems.forEach(item => {
|
21 |
+
item.disabled = isInputEmpty; // Disable/enable the radio buttons
|
22 |
+
});
|
23 |
+
}
|
24 |
+
|
25 |
+
input.addEventListener('input', checkInputValue); // Listen for input changes
|
26 |
+
checkInputValue(); // Initial check when the page loads
|
27 |
};
|
28 |
+
"""
|
29 |
+
)
|
30 |
+
|
31 |
+
# JavaScript to handle the image swapping, reset button, and active class toggling
|
32 |
+
image_swapping = Script(
|
33 |
+
"""
|
34 |
+
document.addEventListener('click', function (e) {
|
35 |
+
if (e.target.classList.contains('sim-map-button') || e.target.classList.contains('reset-button')) {
|
36 |
+
const newSrc = e.target.getAttribute('data-image-src');
|
37 |
+
const img = e.target.closest('.relative').querySelector('.result-image');
|
38 |
+
img.src = newSrc;
|
39 |
+
|
40 |
+
// Remove 'active' class from previously active button
|
41 |
+
const activeButton = document.querySelector('.sim-map-button.active');
|
42 |
+
if (activeButton) {
|
43 |
+
activeButton.classList.remove('active');
|
44 |
+
}
|
45 |
|
46 |
+
// Add 'active' class to the clicked button (if it's a sim-map button)
|
47 |
+
if (e.target.classList.contains('sim-map-button')) {
|
48 |
+
e.target.classList.add('active');
|
49 |
+
}
|
50 |
+
}
|
51 |
+
});
|
52 |
+
"""
|
53 |
+
)
|
54 |
|
55 |
+
|
56 |
+
def SearchBox(with_border=False, query_value="", ranking_value="nn+colpali"):
|
57 |
grid_cls = "grid gap-2 items-center p-3 bg-muted/80 dark:bg-muted/40 w-full"
|
58 |
|
59 |
if with_border:
|
|
|
74 |
cls="relative",
|
75 |
),
|
76 |
Div(
|
77 |
+
Div(
|
78 |
+
Span("Ranking by:", cls="text-muted-foreground text-xs font-semibold"),
|
79 |
+
RadioGroup(
|
80 |
+
Div(
|
81 |
+
RadioGroupItem(value="nn+colpali", id="nn+colpali"),
|
82 |
+
Label("nn+colpali", htmlFor="nn+colpali"),
|
83 |
+
cls="flex items-center space-x-2",
|
84 |
+
),
|
85 |
+
Div(
|
86 |
+
RadioGroupItem(value="bm25+colpali", id="bm25+colpali"),
|
87 |
+
Label("bm25+colpali", htmlFor="bm25+colpali"),
|
88 |
+
cls="flex items-center space-x-2",
|
89 |
+
),
|
90 |
+
Div(
|
91 |
+
RadioGroupItem(value="bm25", id="bm25"),
|
92 |
+
Label("bm25", htmlFor="bm25"),
|
93 |
+
cls="flex items-center space-x-2",
|
94 |
+
),
|
95 |
+
name="ranking",
|
96 |
+
default_value=ranking_value,
|
97 |
+
cls="grid-flow-col gap-x-5 text-muted-foreground",
|
98 |
+
),
|
99 |
+
cls="grid grid-flow-col items-center gap-x-3 border border-input px-3 rounded-sm",
|
100 |
+
),
|
101 |
Button(
|
102 |
Lucide(icon="arrow-right", size="21"),
|
103 |
size="sm",
|
|
|
107 |
),
|
108 |
cls="flex justify-between",
|
109 |
),
|
110 |
+
check_input_script,
|
111 |
+
action=f"/search?query={quote_plus(query_value)}&ranking={quote_plus(ranking_value)}",
|
112 |
method="GET",
|
113 |
+
hx_get=f"/fetch_results?query={quote_plus(query_value)}&ranking={quote_plus(ranking_value)}",
|
114 |
+
hx_trigger="load",
|
115 |
+
hx_target="#search-results",
|
116 |
+
hx_swap="outerHTML",
|
117 |
+
hx_indicator="#loading-indicator",
|
118 |
cls=grid_cls,
|
119 |
)
|
120 |
|
121 |
|
122 |
def SampleQueries():
|
123 |
sample_queries = [
|
124 |
+
"Percentage of non-fresh water as source?",
|
125 |
+
"Policies related to nature risk?",
|
126 |
+
"How much of produced water is recycled?",
|
127 |
]
|
128 |
|
129 |
query_badges = []
|
|
|
139 |
cls="flex gap-2 items-center",
|
140 |
),
|
141 |
variant="outline",
|
142 |
+
cls="text-base font-normal text-muted-foreground hover:border-black dark:hover:border-white",
|
143 |
),
|
144 |
href=f"/search?query={quote_plus(query)}",
|
145 |
cls="no-underline",
|
|
|
152 |
def Hero():
|
153 |
return Div(
|
154 |
H1(
|
155 |
+
"Vespa.ai + ColPali",
|
156 |
cls="text-5xl md:text-7xl font-bold tracking-wide md:tracking-wider bg-clip-text text-transparent bg-gradient-to-r from-black to-gray-700 dark:from-white dark:to-gray-300 animate-fade-in",
|
157 |
),
|
158 |
P(
|
|
|
177 |
|
178 |
def Search(request, search_results=[]):
|
179 |
query_value = request.query_params.get("query", "").strip()
|
180 |
+
ranking_value = request.query_params.get("ranking", "nn+colpali")
|
181 |
+
print(
|
182 |
+
f"Search: Fetching results for query: {query_value}, ranking: {ranking_value}"
|
183 |
+
)
|
184 |
|
185 |
return Div(
|
186 |
Div(
|
187 |
+
SearchBox(query_value=query_value, ranking_value=ranking_value),
|
|
|
|
|
188 |
Div(
|
189 |
LoadingMessage(), # Show the loading message initially
|
190 |
id="search-results", # This will be replaced by the search results
|
|
|
203 |
)
|
204 |
|
205 |
|
206 |
+
def SearchResult(results: list, query_id: Optional[str] = None):
|
207 |
if not results:
|
208 |
return Div(
|
209 |
P(
|
|
|
217 |
result_items = []
|
218 |
for result in results:
|
219 |
fields = result["fields"] # Extract the 'fields' part of each result
|
220 |
+
full_image_base64 = f"data:image/jpeg;base64,{fields['full_image']}"
|
221 |
+
|
222 |
+
# Filter sim_map fields that are words with 4 or more characters
|
223 |
+
sim_map_fields = {
|
224 |
+
key: value
|
225 |
+
for key, value in fields.items()
|
226 |
+
if key.startswith("sim_map_") and len(key.split("_")[-1]) >= 4
|
227 |
+
}
|
228 |
+
|
229 |
+
# Generate buttons for the sim_map fields
|
230 |
+
sim_map_buttons = []
|
231 |
+
for key, value in sim_map_fields.items():
|
232 |
+
sim_map_base64 = f"data:image/jpeg;base64,{value}"
|
233 |
+
sim_map_buttons.append(
|
234 |
+
Button(
|
235 |
+
key.split("_")[-1],
|
236 |
+
size="sm",
|
237 |
+
data_image_src=sim_map_base64,
|
238 |
+
cls="sim-map-button pointer-events-auto font-mono text-xs h-5 rounded-none px-2",
|
239 |
+
)
|
240 |
+
)
|
241 |
+
|
242 |
+
# Add "Reset Image" button to restore the full image
|
243 |
+
reset_button = Button(
|
244 |
+
"Reset",
|
245 |
+
variant="outline",
|
246 |
+
size="sm",
|
247 |
+
data_image_src=full_image_base64,
|
248 |
+
cls="reset-button pointer-events-auto font-mono text-xs h-5 rounded-none px-2",
|
249 |
+
)
|
250 |
+
# Add "Tokens" button - this has no action, just a placeholder
|
251 |
+
tokens_button = Button(
|
252 |
+
Lucide(icon="images", size="15"),
|
253 |
+
"Tokens",
|
254 |
+
size="sm",
|
255 |
+
cls="tokens-button flex gap-[3px] font-bold pointer-events-none font-mono text-xs h-5 rounded-none px-2",
|
256 |
)
|
|
|
|
|
|
|
|
|
257 |
result_items.append(
|
258 |
Div(
|
259 |
Div(
|
260 |
+
Div(
|
261 |
+
tokens_button,
|
262 |
+
*sim_map_buttons,
|
263 |
+
reset_button,
|
264 |
+
cls="flex flex-wrap gap-px w-full pointer-events-none",
|
265 |
+
),
|
266 |
+
Img(
|
267 |
+
src=full_image_base64,
|
268 |
+
alt=fields["title"],
|
269 |
+
cls="result-image max-w-full h-auto",
|
270 |
+
),
|
271 |
+
cls="relative grid gap-px content-start bg-background px-3 py-5",
|
272 |
),
|
273 |
Div(
|
274 |
Div(
|
275 |
H2(fields["title"], cls="text-xl font-semibold"),
|
276 |
P(
|
277 |
+
"Page " + str(fields["page_number"]),
|
278 |
+
cls="text-muted-foreground",
|
279 |
+
),
|
280 |
+
P(
|
281 |
+
"Relevance score: " + str(result["relevance"]),
|
282 |
+
cls="text-muted-foreground",
|
283 |
+
),
|
284 |
+
P(NotStr(fields["snippet"]), cls="text-muted-foreground"),
|
285 |
+
P(NotStr(fields["text"]), cls="text-muted-foreground"),
|
286 |
cls="text-sm grid gap-y-4",
|
287 |
),
|
288 |
+
cls="bg-background px-3 py-5 hidden md:block",
|
289 |
),
|
290 |
+
cls="grid grid-cols-1 md:grid-cols-2 col-span-2",
|
291 |
)
|
292 |
)
|
293 |
|
294 |
+
if query_id is not None:
|
295 |
+
return Div(
|
296 |
+
*result_items,
|
297 |
+
image_swapping,
|
298 |
+
hx_get=f"/updated_search_results?query_id={query_id}",
|
299 |
+
hx_trigger="every 1s",
|
300 |
+
hx_target="#search-results",
|
301 |
+
hx_swap="outerHTML",
|
302 |
+
id="search-results",
|
303 |
+
cls="grid grid-cols-2 gap-px bg-border",
|
304 |
+
)
|
305 |
+
else:
|
306 |
+
return Div(
|
307 |
+
*result_items,
|
308 |
+
image_swapping,
|
309 |
+
id="search-results",
|
310 |
+
cls="grid grid-cols-2 gap-px bg-border",
|
311 |
+
)
|
globals.css
CHANGED
@@ -155,3 +155,17 @@
|
|
155 |
.animate-slide-up {
|
156 |
animation: slide-up 1s ease-out forwards;
|
157 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
.animate-slide-up {
|
156 |
animation: slide-up 1s ease-out forwards;
|
157 |
}
|
158 |
+
|
159 |
+
.sim-map-button.active {
|
160 |
+
background-color: #61D790;
|
161 |
+
color: #2E2F27;
|
162 |
+
|
163 |
+
&:hover {
|
164 |
+
background-color: #61D790;
|
165 |
+
}
|
166 |
+
}
|
167 |
+
|
168 |
+
.tokens-button {
|
169 |
+
background-color: #B7E2F1;
|
170 |
+
color: #2E2F27;
|
171 |
+
}
|
icons.py
CHANGED
@@ -1 +1 @@
|
|
1 |
-
ICONS = {"chevrons-right": "<path d=\"m6 17 5-5-5-5\"></path><path d=\"m13 17 5-5-5-5\"></path>", "moon": "<path d=\"M12 3a6 6 0 0 0 9 9 9 9 0 1 1-9-9Z\"></path>", "sun": "<circle cx=\"12\" cy=\"12\" r=\"4\"></circle><path d=\"M12 2v2\"></path><path d=\"M12 20v2\"></path><path d=\"m4.93 4.93 1.41 1.41\"></path><path d=\"m17.66 17.66 1.41 1.41\"></path><path d=\"M2 12h2\"></path><path d=\"M20 12h2\"></path><path d=\"m6.34 17.66-1.41 1.41\"></path><path d=\"m19.07 4.93-1.41 1.41\"></path>", "github": "<path d=\"M15 22v-4a4.8 4.8 0 0 0-1-3.5c3 0 6-2 6-5.5.08-1.25-.27-2.48-1-3.5.28-1.15.28-2.35 0-3.5 0 0-1 0-3 1.5-2.64-.5-5.36-.5-8 0C6 2 5 2 5 2c-.3 1.15-.3 2.35 0 3.5A5.403 5.403 0 0 0 4 9c0 3.5 3 5.5 6 5.5-.39.49-.68 1.05-.85 1.65-.17.6-.22 1.23-.15 1.85v4\"></path><path d=\"M9 18c-4.51 2-5-2-7-2\"></path>", "slack": "<rect height=\"8\" rx=\"1.5\" width=\"3\" x=\"13\" y=\"2\"></rect><path d=\"M19 8.5V10h1.5A1.5 1.5 0 1 0 19 8.5\"></path><rect height=\"8\" rx=\"1.5\" width=\"3\" x=\"8\" y=\"14\"></rect><path d=\"M5 15.5V14H3.5A1.5 1.5 0 1 0 5 15.5\"></path><rect height=\"3\" rx=\"1.5\" width=\"8\" x=\"14\" y=\"13\"></rect><path d=\"M15.5 19H14v1.5a1.5 1.5 0 1 0 1.5-1.5\"></path><rect height=\"3\" rx=\"1.5\" width=\"8\" x=\"2\" y=\"8\"></rect><path d=\"M8.5 5H10V3.5A1.5 1.5 0 1 0 8.5 5\"></path>", "settings": "<path d=\"M12.22 2h-.44a2 2 0 0 0-2 2v.18a2 2 0 0 1-1 1.73l-.43.25a2 2 0 0 1-2 0l-.15-.08a2 2 0 0 0-2.73.73l-.22.38a2 2 0 0 0 .73 2.73l.15.1a2 2 0 0 1 1 1.72v.51a2 2 0 0 1-1 1.74l-.15.09a2 2 0 0 0-.73 2.73l.22.38a2 2 0 0 0 2.73.73l.15-.08a2 2 0 0 1 2 0l.43.25a2 2 0 0 1 1 1.73V20a2 2 0 0 0 2 2h.44a2 2 0 0 0 2-2v-.18a2 2 0 0 1 1-1.73l.43-.25a2 2 0 0 1 2 0l.15.08a2 2 0 0 0 2.73-.73l.22-.39a2 2 0 0 0-.73-2.73l-.15-.08a2 2 0 0 1-1-1.74v-.5a2 2 0 0 1 1-1.74l.15-.09a2 2 0 0 0 .73-2.73l-.22-.38a2 2 0 0 0-2.73-.73l-.15.08a2 2 0 0 1-2 0l-.43-.25a2 2 0 0 1-1-1.73V4a2 2 0 0 0-2-2z\"></path><circle cx=\"12\" cy=\"12\" r=\"3\"></circle>", "arrow-right": "<path d=\"M5 12h14\"></path><path d=\"m12 5 7 7-7 7\"></path>", "search": "<circle cx=\"11\" cy=\"11\" r=\"8\"></circle><path d=\"m21 21-4.3-4.3\"></path>", "file-search": "<path d=\"M14 2v4a2 2 0 0 0 2 2h4\"></path><path d=\"M4.268 21a2 2 0 0 0 1.727 1H18a2 2 0 0 0 2-2V7l-5-5H6a2 2 0 0 0-2 2v3\"></path><path d=\"m9 18-1.5-1.5\"></path><circle cx=\"5\" cy=\"14\" r=\"3\"></circle>", "message-circle-question": "<path d=\"M7.9 20A9 9 0 1 0 4 16.1L2 22Z\"></path><path d=\"M9.09 9a3 3 0 0 1 5.83 1c0 2-3 3-3 3\"></path><path d=\"M12 17h.01\"></path>", "text-search": "<path d=\"M21 6H3\"></path><path d=\"M10 12H3\"></path><path d=\"M10 18H3\"></path><circle cx=\"17\" cy=\"15\" r=\"3\"></circle><path d=\"m21 19-1.9-1.9\"></path>"}
|
|
|
1 |
+
ICONS = {"chevrons-right": "<path d=\"m6 17 5-5-5-5\"></path><path d=\"m13 17 5-5-5-5\"></path>", "moon": "<path d=\"M12 3a6 6 0 0 0 9 9 9 9 0 1 1-9-9Z\"></path>", "sun": "<circle cx=\"12\" cy=\"12\" r=\"4\"></circle><path d=\"M12 2v2\"></path><path d=\"M12 20v2\"></path><path d=\"m4.93 4.93 1.41 1.41\"></path><path d=\"m17.66 17.66 1.41 1.41\"></path><path d=\"M2 12h2\"></path><path d=\"M20 12h2\"></path><path d=\"m6.34 17.66-1.41 1.41\"></path><path d=\"m19.07 4.93-1.41 1.41\"></path>", "github": "<path d=\"M15 22v-4a4.8 4.8 0 0 0-1-3.5c3 0 6-2 6-5.5.08-1.25-.27-2.48-1-3.5.28-1.15.28-2.35 0-3.5 0 0-1 0-3 1.5-2.64-.5-5.36-.5-8 0C6 2 5 2 5 2c-.3 1.15-.3 2.35 0 3.5A5.403 5.403 0 0 0 4 9c0 3.5 3 5.5 6 5.5-.39.49-.68 1.05-.85 1.65-.17.6-.22 1.23-.15 1.85v4\"></path><path d=\"M9 18c-4.51 2-5-2-7-2\"></path>", "slack": "<rect height=\"8\" rx=\"1.5\" width=\"3\" x=\"13\" y=\"2\"></rect><path d=\"M19 8.5V10h1.5A1.5 1.5 0 1 0 19 8.5\"></path><rect height=\"8\" rx=\"1.5\" width=\"3\" x=\"8\" y=\"14\"></rect><path d=\"M5 15.5V14H3.5A1.5 1.5 0 1 0 5 15.5\"></path><rect height=\"3\" rx=\"1.5\" width=\"8\" x=\"14\" y=\"13\"></rect><path d=\"M15.5 19H14v1.5a1.5 1.5 0 1 0 1.5-1.5\"></path><rect height=\"3\" rx=\"1.5\" width=\"8\" x=\"2\" y=\"8\"></rect><path d=\"M8.5 5H10V3.5A1.5 1.5 0 1 0 8.5 5\"></path>", "settings": "<path d=\"M12.22 2h-.44a2 2 0 0 0-2 2v.18a2 2 0 0 1-1 1.73l-.43.25a2 2 0 0 1-2 0l-.15-.08a2 2 0 0 0-2.73.73l-.22.38a2 2 0 0 0 .73 2.73l.15.1a2 2 0 0 1 1 1.72v.51a2 2 0 0 1-1 1.74l-.15.09a2 2 0 0 0-.73 2.73l.22.38a2 2 0 0 0 2.73.73l.15-.08a2 2 0 0 1 2 0l.43.25a2 2 0 0 1 1 1.73V20a2 2 0 0 0 2 2h.44a2 2 0 0 0 2-2v-.18a2 2 0 0 1 1-1.73l.43-.25a2 2 0 0 1 2 0l.15.08a2 2 0 0 0 2.73-.73l.22-.39a2 2 0 0 0-.73-2.73l-.15-.08a2 2 0 0 1-1-1.74v-.5a2 2 0 0 1 1-1.74l.15-.09a2 2 0 0 0 .73-2.73l-.22-.38a2 2 0 0 0-2.73-.73l-.15.08a2 2 0 0 1-2 0l-.43-.25a2 2 0 0 1-1-1.73V4a2 2 0 0 0-2-2z\"></path><circle cx=\"12\" cy=\"12\" r=\"3\"></circle>", "arrow-right": "<path d=\"M5 12h14\"></path><path d=\"m12 5 7 7-7 7\"></path>", "search": "<circle cx=\"11\" cy=\"11\" r=\"8\"></circle><path d=\"m21 21-4.3-4.3\"></path>", "file-search": "<path d=\"M14 2v4a2 2 0 0 0 2 2h4\"></path><path d=\"M4.268 21a2 2 0 0 0 1.727 1H18a2 2 0 0 0 2-2V7l-5-5H6a2 2 0 0 0-2 2v3\"></path><path d=\"m9 18-1.5-1.5\"></path><circle cx=\"5\" cy=\"14\" r=\"3\"></circle>", "message-circle-question": "<path d=\"M7.9 20A9 9 0 1 0 4 16.1L2 22Z\"></path><path d=\"M9.09 9a3 3 0 0 1 5.83 1c0 2-3 3-3 3\"></path><path d=\"M12 17h.01\"></path>", "text-search": "<path d=\"M21 6H3\"></path><path d=\"M10 12H3\"></path><path d=\"M10 18H3\"></path><circle cx=\"17\" cy=\"15\" r=\"3\"></circle><path d=\"m21 19-1.9-1.9\"></path>", "maximize": "<path d=\"M8 3H5a2 2 0 0 0-2 2v3\"></path><path d=\"M21 8V5a2 2 0 0 0-2-2h-3\"></path><path d=\"M3 16v3a2 2 0 0 0 2 2h3\"></path><path d=\"M16 21h3a2 2 0 0 0 2-2v-3\"></path>", "expand": "<path d=\"m21 21-6-6m6 6v-4.8m0 4.8h-4.8\"></path><path d=\"M3 16.2V21m0 0h4.8M3 21l6-6\"></path><path d=\"M21 7.8V3m0 0h-4.8M21 3l-6 6\"></path><path d=\"M3 7.8V3m0 0h4.8M3 3l6 6\"></path>", "fullscreen": "<path d=\"M3 7V5a2 2 0 0 1 2-2h2\"></path><path d=\"M17 3h2a2 2 0 0 1 2 2v2\"></path><path d=\"M21 17v2a2 2 0 0 1-2 2h-2\"></path><path d=\"M7 21H5a2 2 0 0 1-2-2v-2\"></path><rect height=\"8\" rx=\"1\" width=\"10\" x=\"7\" y=\"8\"></rect>", "images": "<path d=\"M18 22H4a2 2 0 0 1-2-2V6\"></path><path d=\"m22 13-1.296-1.296a2.41 2.41 0 0 0-3.408 0L11 18\"></path><circle cx=\"12\" cy=\"8\" r=\"2\"></circle><rect height=\"16\" rx=\"2\" width=\"16\" x=\"6\" y=\"2\"></rect>", "circle": "<circle cx=\"12\" cy=\"12\" r=\"10\"></circle>"}
|
main.py
CHANGED
@@ -1,14 +1,23 @@
|
|
1 |
import asyncio
|
2 |
-
import
|
|
|
3 |
|
4 |
from fasthtml.common import *
|
5 |
from shad4fast import *
|
6 |
from vespa.application import Vespa
|
|
|
7 |
|
8 |
-
from backend.colpali import
|
|
|
|
|
|
|
|
|
9 |
from backend.vespa_app import get_vespa_app
|
10 |
-
from
|
|
|
|
|
11 |
from frontend.layout import Layout
|
|
|
12 |
|
13 |
highlight_js_theme_link = Link(id="highlight-theme", rel="stylesheet", href="")
|
14 |
highlight_js_theme = Script(src="/static/js/highlightjs-theme.js")
|
@@ -30,26 +39,12 @@ app, rt = fast_app(
|
|
30 |
)
|
31 |
vespa_app: Vespa = get_vespa_app()
|
32 |
|
|
|
|
|
33 |
|
34 |
-
class ModelManager:
|
35 |
-
_instance = None
|
36 |
-
model = None
|
37 |
-
processor = None
|
38 |
-
|
39 |
-
@staticmethod
|
40 |
-
def get_instance():
|
41 |
-
if ModelManager._instance is None:
|
42 |
-
ModelManager._instance = ModelManager()
|
43 |
-
ModelManager._instance.initialize_model_and_processor()
|
44 |
-
return ModelManager._instance
|
45 |
|
46 |
-
|
47 |
-
|
48 |
-
self.model, self.processor = load_model()
|
49 |
-
if self.model is None or self.processor is None:
|
50 |
-
print("Failed to initialize model or processor at startup")
|
51 |
-
else:
|
52 |
-
print("Model and processor loaded at startup")
|
53 |
|
54 |
|
55 |
@rt("/static/{filepath:path}")
|
@@ -64,15 +59,17 @@ def get():
|
|
64 |
|
65 |
@rt("/search")
|
66 |
def get(request):
|
67 |
-
# Extract the 'query'
|
68 |
query_value = request.query_params.get("query", "").strip()
|
|
|
|
|
69 |
|
70 |
# Always render the SearchBox first
|
71 |
if not query_value:
|
72 |
# Show SearchBox and a message for missing query
|
73 |
return Layout(
|
74 |
Div(
|
75 |
-
SearchBox(query_value=query_value),
|
76 |
Div(
|
77 |
P(
|
78 |
"No query provided. Please enter a query.",
|
@@ -89,39 +86,80 @@ def get(request):
|
|
89 |
|
90 |
|
91 |
@rt("/fetch_results")
|
92 |
-
def get(request, query: str, nn: bool = True):
|
93 |
-
# Check if the request came from HTMX; if not, redirect to /search
|
94 |
if "hx-request" not in request.headers:
|
95 |
return RedirectResponse("/search")
|
96 |
|
97 |
-
# Extract
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
|
99 |
# Fetch model and processor
|
100 |
manager = ModelManager.get_instance()
|
101 |
model = manager.model
|
102 |
processor = manager.processor
|
|
|
103 |
|
|
|
104 |
# Fetch real search results from Vespa
|
105 |
-
result =
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
)
|
114 |
)
|
115 |
-
|
116 |
-
# Extract search results from the result payload
|
117 |
search_results = (
|
118 |
result["root"]["children"]
|
119 |
if "root" in result and "children" in result["root"]
|
120 |
else []
|
121 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
-
|
124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
|
126 |
|
127 |
@rt("/app")
|
@@ -129,24 +167,6 @@ def get():
|
|
129 |
return Layout(Div(P(f"Connected to Vespa at {vespa_app.url}"), cls="p-4"))
|
130 |
|
131 |
|
132 |
-
@rt("/run_query")
|
133 |
-
def get(query: str, nn: bool = False):
|
134 |
-
# dummy-function to avoid running the query every time
|
135 |
-
# result = get_result_dummy(query, nn)
|
136 |
-
# If we want to run real, uncomment the following lines
|
137 |
-
model, processor = get_model_and_processor()
|
138 |
-
result = asyncio.run(
|
139 |
-
get_result_from_query(
|
140 |
-
vespa_app, processor=processor, model=model, query=query, nn=nn
|
141 |
-
)
|
142 |
-
)
|
143 |
-
# model, processor = get_model_and_processor()
|
144 |
-
# result = asyncio.run(
|
145 |
-
# get_result_from_query(vespa_app, processor=processor, model=model, query=query, nn=nn)
|
146 |
-
# )
|
147 |
-
return Layout(Div(H1("Result"), Pre(Code(json.dumps(result, indent=2))), cls="p-4"))
|
148 |
-
|
149 |
-
|
150 |
if __name__ == "__main__":
|
151 |
# ModelManager.get_instance() # Initialize once at startup
|
152 |
serve(port=7860)
|
|
|
1 |
import asyncio
|
2 |
+
from concurrent.futures import ThreadPoolExecutor
|
3 |
+
from functools import partial
|
4 |
|
5 |
from fasthtml.common import *
|
6 |
from shad4fast import *
|
7 |
from vespa.application import Vespa
|
8 |
+
import time
|
9 |
|
10 |
+
from backend.colpali import (
|
11 |
+
get_result_from_query,
|
12 |
+
get_query_embeddings_and_token_map,
|
13 |
+
add_sim_maps_to_result,
|
14 |
+
)
|
15 |
from backend.vespa_app import get_vespa_app
|
16 |
+
from backend.cache import LRUCache
|
17 |
+
from backend.modelmanager import ModelManager
|
18 |
+
from frontend.app import Home, Search, SearchBox, SearchResult
|
19 |
from frontend.layout import Layout
|
20 |
+
import hashlib
|
21 |
|
22 |
highlight_js_theme_link = Link(id="highlight-theme", rel="stylesheet", href="")
|
23 |
highlight_js_theme = Script(src="/static/js/highlightjs-theme.js")
|
|
|
39 |
)
|
40 |
vespa_app: Vespa = get_vespa_app()
|
41 |
|
42 |
+
result_cache = LRUCache(max_size=20) # Each result can be ~10MB
|
43 |
+
thread_pool = ThreadPoolExecutor()
|
44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
+
def generate_query_id(query):
|
47 |
+
return hashlib.md5(query.encode("utf-8")).hexdigest()
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
|
50 |
@rt("/static/{filepath:path}")
|
|
|
59 |
|
60 |
@rt("/search")
|
61 |
def get(request):
|
62 |
+
# Extract the 'query' and 'ranking' parameters from the URL
|
63 |
query_value = request.query_params.get("query", "").strip()
|
64 |
+
ranking_value = request.query_params.get("ranking", "nn+colpali")
|
65 |
+
print("/search: Fetching results for ranking_value:", ranking_value)
|
66 |
|
67 |
# Always render the SearchBox first
|
68 |
if not query_value:
|
69 |
# Show SearchBox and a message for missing query
|
70 |
return Layout(
|
71 |
Div(
|
72 |
+
SearchBox(query_value=query_value, ranking_value=ranking_value),
|
73 |
Div(
|
74 |
P(
|
75 |
"No query provided. Please enter a query.",
|
|
|
86 |
|
87 |
|
88 |
@rt("/fetch_results")
|
89 |
+
async def get(request, query: str, nn: bool = True):
|
|
|
90 |
if "hx-request" not in request.headers:
|
91 |
return RedirectResponse("/search")
|
92 |
|
93 |
+
# Extract ranking option from the request
|
94 |
+
ranking_value = request.query_params.get("ranking")
|
95 |
+
print(
|
96 |
+
f"/fetch_results: Fetching results for query: {query}, ranking: {ranking_value}"
|
97 |
+
)
|
98 |
+
# Generate a unique query_id based on the query and ranking value
|
99 |
+
query_id = generate_query_id(query + ranking_value)
|
100 |
|
101 |
# Fetch model and processor
|
102 |
manager = ModelManager.get_instance()
|
103 |
model = manager.model
|
104 |
processor = manager.processor
|
105 |
+
q_embs, token_to_idx = get_query_embeddings_and_token_map(processor, model, query)
|
106 |
|
107 |
+
start = time.perf_counter()
|
108 |
# Fetch real search results from Vespa
|
109 |
+
result = await get_result_from_query(
|
110 |
+
app=vespa_app,
|
111 |
+
processor=processor,
|
112 |
+
model=model,
|
113 |
+
query=query,
|
114 |
+
q_embs=q_embs,
|
115 |
+
token_to_idx=token_to_idx,
|
116 |
+
ranking=ranking_value,
|
117 |
+
)
|
118 |
+
end = time.perf_counter()
|
119 |
+
print(f"Search results fetched in {end - start:.2f} seconds, Vespa says searchtime was {result['timing']['searchtime']} seconds")
|
120 |
+
# Start generating the similarity map in the background
|
121 |
+
asyncio.create_task(
|
122 |
+
generate_similarity_map(
|
123 |
+
model, processor, query, q_embs, token_to_idx, result, query_id
|
124 |
)
|
125 |
)
|
|
|
|
|
126 |
search_results = (
|
127 |
result["root"]["children"]
|
128 |
if "root" in result and "children" in result["root"]
|
129 |
else []
|
130 |
)
|
131 |
+
return SearchResult(search_results, query_id)
|
132 |
+
|
133 |
+
|
134 |
+
async def generate_similarity_map(
|
135 |
+
model, processor, query, q_embs, token_to_idx, result, query_id
|
136 |
+
):
|
137 |
+
loop = asyncio.get_event_loop()
|
138 |
+
sim_map_task = partial(
|
139 |
+
add_sim_maps_to_result,
|
140 |
+
result=result,
|
141 |
+
model=model,
|
142 |
+
processor=processor,
|
143 |
+
query=query,
|
144 |
+
q_embs=q_embs,
|
145 |
+
token_to_idx=token_to_idx,
|
146 |
+
)
|
147 |
+
sim_map_result = await loop.run_in_executor(thread_pool, sim_map_task)
|
148 |
+
result_cache.set(query_id, sim_map_result)
|
149 |
+
|
150 |
|
151 |
+
@app.get("/updated_search_results")
|
152 |
+
async def updated_search_results(query_id: str):
|
153 |
+
data = result_cache.get(query_id)
|
154 |
+
if data is None:
|
155 |
+
return HTMLResponse(status_code=204)
|
156 |
+
search_results = (
|
157 |
+
data["root"]["children"]
|
158 |
+
if "root" in data and "children" in data["root"]
|
159 |
+
else []
|
160 |
+
)
|
161 |
+
updated_content = SearchResult(results=search_results, query_id=None)
|
162 |
+
return updated_content
|
163 |
|
164 |
|
165 |
@rt("/app")
|
|
|
167 |
return Layout(Div(P(f"Connected to Vespa at {vespa_app.url}"), cls="p-4"))
|
168 |
|
169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
if __name__ == "__main__":
|
171 |
# ModelManager.get_instance() # Initialize once at startup
|
172 |
serve(port=7860)
|
output.css
CHANGED
@@ -1073,12 +1073,16 @@ body {
|
|
1073 |
resize: both;
|
1074 |
}
|
1075 |
|
1076 |
-
.grid-
|
1077 |
-
grid-
|
1078 |
}
|
1079 |
|
1080 |
-
.grid-cols-
|
1081 |
-
grid-template-columns:
|
|
|
|
|
|
|
|
|
1082 |
}
|
1083 |
|
1084 |
.flex-col {
|
@@ -1089,6 +1093,14 @@ body {
|
|
1089 |
flex-direction: column-reverse;
|
1090 |
}
|
1091 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1092 |
.items-center {
|
1093 |
align-items: center;
|
1094 |
}
|
@@ -1121,18 +1133,34 @@ body {
|
|
1121 |
gap: 2rem;
|
1122 |
}
|
1123 |
|
|
|
|
|
|
|
|
|
1124 |
.gap-px {
|
1125 |
gap: 1px;
|
1126 |
}
|
1127 |
|
1128 |
-
.gap-3 {
|
1129 |
-
gap: 0.75rem;
|
|
|
|
|
|
|
|
|
|
|
|
|
1130 |
}
|
1131 |
|
1132 |
.gap-y-4 {
|
1133 |
row-gap: 1rem;
|
1134 |
}
|
1135 |
|
|
|
|
|
|
|
|
|
|
|
|
|
1136 |
.space-x-3 > :not([hidden]) ~ :not([hidden]) {
|
1137 |
--tw-space-x-reverse: 0;
|
1138 |
margin-right: calc(0.75rem * var(--tw-space-x-reverse));
|
@@ -1193,6 +1221,10 @@ body {
|
|
1193 |
border-radius: calc(var(--radius) - 2px);
|
1194 |
}
|
1195 |
|
|
|
|
|
|
|
|
|
1196 |
.rounded-sm {
|
1197 |
border-radius: calc(var(--radius) - 4px);
|
1198 |
}
|
@@ -1339,8 +1371,8 @@ body {
|
|
1339 |
padding: 0.25rem;
|
1340 |
}
|
1341 |
|
1342 |
-
.p-
|
1343 |
-
padding:
|
1344 |
}
|
1345 |
|
1346 |
.p-3 {
|
@@ -1359,8 +1391,9 @@ body {
|
|
1359 |
padding: 1px;
|
1360 |
}
|
1361 |
|
1362 |
-
.
|
1363 |
-
padding:
|
|
|
1364 |
}
|
1365 |
|
1366 |
.px-2\.5 {
|
@@ -1448,6 +1481,10 @@ body {
|
|
1448 |
vertical-align: middle;
|
1449 |
}
|
1450 |
|
|
|
|
|
|
|
|
|
1451 |
.text-2xl {
|
1452 |
font-size: 1.5rem;
|
1453 |
line-height: 2rem;
|
@@ -1496,14 +1533,14 @@ body {
|
|
1496 |
font-weight: 500;
|
1497 |
}
|
1498 |
|
1499 |
-
.font-semibold {
|
1500 |
-
font-weight: 600;
|
1501 |
-
}
|
1502 |
-
|
1503 |
.font-normal {
|
1504 |
font-weight: 400;
|
1505 |
}
|
1506 |
|
|
|
|
|
|
|
|
|
1507 |
.leading-none {
|
1508 |
line-height: 1;
|
1509 |
}
|
@@ -1574,10 +1611,6 @@ body {
|
|
1574 |
color: transparent;
|
1575 |
}
|
1576 |
|
1577 |
-
.underline {
|
1578 |
-
text-decoration-line: underline;
|
1579 |
-
}
|
1580 |
-
|
1581 |
.no-underline {
|
1582 |
text-decoration-line: none;
|
1583 |
}
|
@@ -1908,6 +1941,19 @@ body {
|
|
1908 |
animation: slide-up 1s ease-out forwards;
|
1909 |
}
|
1910 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1911 |
:root:has(.data-\[state\=open\]\:no-bg-scroll[data-state="open"]) {
|
1912 |
overflow: hidden;
|
1913 |
}
|
@@ -1964,13 +2010,9 @@ body {
|
|
1964 |
--tw-ring-offset-width: 2px;
|
1965 |
}
|
1966 |
|
1967 |
-
.hover\:border-
|
1968 |
--tw-border-opacity: 1;
|
1969 |
-
border-color: rgb(
|
1970 |
-
}
|
1971 |
-
|
1972 |
-
.hover\:border-\[text-muted-foreground\]:hover {
|
1973 |
-
border-color: text-muted-foreground;
|
1974 |
}
|
1975 |
|
1976 |
.hover\:bg-accent:hover {
|
@@ -2001,10 +2043,6 @@ body {
|
|
2001 |
background-color: hsl(var(--secondary) / 0.8);
|
2002 |
}
|
2003 |
|
2004 |
-
.hover\:bg-secondary:hover {
|
2005 |
-
background-color: hsl(var(--secondary));
|
2006 |
-
}
|
2007 |
-
|
2008 |
.hover\:text-accent-foreground:hover {
|
2009 |
color: hsl(var(--accent-foreground));
|
2010 |
}
|
@@ -2013,14 +2051,6 @@ body {
|
|
2013 |
color: hsl(var(--foreground));
|
2014 |
}
|
2015 |
|
2016 |
-
.hover\:text-primary-foreground:hover {
|
2017 |
-
color: hsl(var(--primary-foreground));
|
2018 |
-
}
|
2019 |
-
|
2020 |
-
.hover\:text-muted-foreground:hover {
|
2021 |
-
color: hsl(var(--muted-foreground));
|
2022 |
-
}
|
2023 |
-
|
2024 |
.hover\:underline:hover {
|
2025 |
text-decoration-line: underline;
|
2026 |
}
|
@@ -2407,10 +2437,18 @@ body {
|
|
2407 |
}
|
2408 |
|
2409 |
@media (min-width: 768px) {
|
|
|
|
|
|
|
|
|
2410 |
.md\:max-w-\[420px\] {
|
2411 |
max-width: 420px;
|
2412 |
}
|
2413 |
|
|
|
|
|
|
|
|
|
2414 |
.md\:text-2xl {
|
2415 |
font-size: 1.5rem;
|
2416 |
line-height: 2rem;
|
@@ -2460,14 +2498,9 @@ body {
|
|
2460 |
--tw-gradient-to: #d1d5db var(--tw-gradient-to-position);
|
2461 |
}
|
2462 |
|
2463 |
-
.dark\:hover\:border-
|
2464 |
-
--tw-border-opacity: 1;
|
2465 |
-
border-color: rgb(0 0 0 / var(--tw-border-opacity));
|
2466 |
-
}
|
2467 |
-
|
2468 |
-
.hover\:dark\:border-black:where(.dark, .dark *):hover {
|
2469 |
--tw-border-opacity: 1;
|
2470 |
-
border-color: rgb(
|
2471 |
}
|
2472 |
|
2473 |
.\[\&\:has\(\[role\=checkbox\]\)\]\:pr-0:has([role=checkbox]) {
|
|
|
1073 |
resize: both;
|
1074 |
}
|
1075 |
|
1076 |
+
.grid-flow-col {
|
1077 |
+
grid-auto-flow: column;
|
1078 |
}
|
1079 |
|
1080 |
+
.grid-cols-1 {
|
1081 |
+
grid-template-columns: repeat(1, minmax(0, 1fr));
|
1082 |
+
}
|
1083 |
+
|
1084 |
+
.grid-cols-2 {
|
1085 |
+
grid-template-columns: repeat(2, minmax(0, 1fr));
|
1086 |
}
|
1087 |
|
1088 |
.flex-col {
|
|
|
1093 |
flex-direction: column-reverse;
|
1094 |
}
|
1095 |
|
1096 |
+
.flex-wrap {
|
1097 |
+
flex-wrap: wrap;
|
1098 |
+
}
|
1099 |
+
|
1100 |
+
.content-start {
|
1101 |
+
align-content: flex-start;
|
1102 |
+
}
|
1103 |
+
|
1104 |
.items-center {
|
1105 |
align-items: center;
|
1106 |
}
|
|
|
1133 |
gap: 2rem;
|
1134 |
}
|
1135 |
|
1136 |
+
.gap-\[3px\] {
|
1137 |
+
gap: 3px;
|
1138 |
+
}
|
1139 |
+
|
1140 |
.gap-px {
|
1141 |
gap: 1px;
|
1142 |
}
|
1143 |
|
1144 |
+
.gap-x-3 {
|
1145 |
+
-moz-column-gap: 0.75rem;
|
1146 |
+
column-gap: 0.75rem;
|
1147 |
+
}
|
1148 |
+
|
1149 |
+
.gap-x-5 {
|
1150 |
+
-moz-column-gap: 1.25rem;
|
1151 |
+
column-gap: 1.25rem;
|
1152 |
}
|
1153 |
|
1154 |
.gap-y-4 {
|
1155 |
row-gap: 1rem;
|
1156 |
}
|
1157 |
|
1158 |
+
.space-x-2 > :not([hidden]) ~ :not([hidden]) {
|
1159 |
+
--tw-space-x-reverse: 0;
|
1160 |
+
margin-right: calc(0.5rem * var(--tw-space-x-reverse));
|
1161 |
+
margin-left: calc(0.5rem * calc(1 - var(--tw-space-x-reverse)));
|
1162 |
+
}
|
1163 |
+
|
1164 |
.space-x-3 > :not([hidden]) ~ :not([hidden]) {
|
1165 |
--tw-space-x-reverse: 0;
|
1166 |
margin-right: calc(0.75rem * var(--tw-space-x-reverse));
|
|
|
1221 |
border-radius: calc(var(--radius) - 2px);
|
1222 |
}
|
1223 |
|
1224 |
+
.rounded-none {
|
1225 |
+
border-radius: 0px;
|
1226 |
+
}
|
1227 |
+
|
1228 |
.rounded-sm {
|
1229 |
border-radius: calc(var(--radius) - 4px);
|
1230 |
}
|
|
|
1371 |
padding: 0.25rem;
|
1372 |
}
|
1373 |
|
1374 |
+
.p-10 {
|
1375 |
+
padding: 2.5rem;
|
1376 |
}
|
1377 |
|
1378 |
.p-3 {
|
|
|
1391 |
padding: 1px;
|
1392 |
}
|
1393 |
|
1394 |
+
.px-2 {
|
1395 |
+
padding-left: 0.5rem;
|
1396 |
+
padding-right: 0.5rem;
|
1397 |
}
|
1398 |
|
1399 |
.px-2\.5 {
|
|
|
1481 |
vertical-align: middle;
|
1482 |
}
|
1483 |
|
1484 |
+
.font-mono {
|
1485 |
+
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;
|
1486 |
+
}
|
1487 |
+
|
1488 |
.text-2xl {
|
1489 |
font-size: 1.5rem;
|
1490 |
line-height: 2rem;
|
|
|
1533 |
font-weight: 500;
|
1534 |
}
|
1535 |
|
|
|
|
|
|
|
|
|
1536 |
.font-normal {
|
1537 |
font-weight: 400;
|
1538 |
}
|
1539 |
|
1540 |
+
.font-semibold {
|
1541 |
+
font-weight: 600;
|
1542 |
+
}
|
1543 |
+
|
1544 |
.leading-none {
|
1545 |
line-height: 1;
|
1546 |
}
|
|
|
1611 |
color: transparent;
|
1612 |
}
|
1613 |
|
|
|
|
|
|
|
|
|
1614 |
.no-underline {
|
1615 |
text-decoration-line: none;
|
1616 |
}
|
|
|
1941 |
animation: slide-up 1s ease-out forwards;
|
1942 |
}
|
1943 |
|
1944 |
+
.sim-map-button.active {
|
1945 |
+
background-color: #61D790;
|
1946 |
+
color: #2E2F27;
|
1947 |
+
&:hover {
|
1948 |
+
background-color: #61D790;
|
1949 |
+
}
|
1950 |
+
}
|
1951 |
+
|
1952 |
+
.tokens-button {
|
1953 |
+
background-color: #B7E2F1;
|
1954 |
+
color: #2E2F27;
|
1955 |
+
}
|
1956 |
+
|
1957 |
:root:has(.data-\[state\=open\]\:no-bg-scroll[data-state="open"]) {
|
1958 |
overflow: hidden;
|
1959 |
}
|
|
|
2010 |
--tw-ring-offset-width: 2px;
|
2011 |
}
|
2012 |
|
2013 |
+
.hover\:border-black:hover {
|
2014 |
--tw-border-opacity: 1;
|
2015 |
+
border-color: rgb(0 0 0 / var(--tw-border-opacity));
|
|
|
|
|
|
|
|
|
2016 |
}
|
2017 |
|
2018 |
.hover\:bg-accent:hover {
|
|
|
2043 |
background-color: hsl(var(--secondary) / 0.8);
|
2044 |
}
|
2045 |
|
|
|
|
|
|
|
|
|
2046 |
.hover\:text-accent-foreground:hover {
|
2047 |
color: hsl(var(--accent-foreground));
|
2048 |
}
|
|
|
2051 |
color: hsl(var(--foreground));
|
2052 |
}
|
2053 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2054 |
.hover\:underline:hover {
|
2055 |
text-decoration-line: underline;
|
2056 |
}
|
|
|
2437 |
}
|
2438 |
|
2439 |
@media (min-width: 768px) {
|
2440 |
+
.md\:block {
|
2441 |
+
display: block;
|
2442 |
+
}
|
2443 |
+
|
2444 |
.md\:max-w-\[420px\] {
|
2445 |
max-width: 420px;
|
2446 |
}
|
2447 |
|
2448 |
+
.md\:grid-cols-2 {
|
2449 |
+
grid-template-columns: repeat(2, minmax(0, 1fr));
|
2450 |
+
}
|
2451 |
+
|
2452 |
.md\:text-2xl {
|
2453 |
font-size: 1.5rem;
|
2454 |
line-height: 2rem;
|
|
|
2498 |
--tw-gradient-to: #d1d5db var(--tw-gradient-to-position);
|
2499 |
}
|
2500 |
|
2501 |
+
.dark\:hover\:border-white:hover:where(.dark, .dark *) {
|
|
|
|
|
|
|
|
|
|
|
2502 |
--tw-border-opacity: 1;
|
2503 |
+
border-color: rgb(255 255 255 / var(--tw-border-opacity));
|
2504 |
}
|
2505 |
|
2506 |
.\[\&\:has\(\[role\=checkbox\]\)\]\:pr-0:has([role=checkbox]) {
|
ruff.toml
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Exclude a variety of commonly ignored directories.
|
2 |
+
exclude = [
|
3 |
+
".bzr",
|
4 |
+
".direnv",
|
5 |
+
".eggs",
|
6 |
+
".git",
|
7 |
+
".git-rewrite",
|
8 |
+
".hg",
|
9 |
+
".ipynb_checkpoints",
|
10 |
+
".mypy_cache",
|
11 |
+
".nox",
|
12 |
+
".pants.d",
|
13 |
+
".pyenv",
|
14 |
+
".pytest_cache",
|
15 |
+
".pytype",
|
16 |
+
".ruff_cache",
|
17 |
+
".svn",
|
18 |
+
".tox",
|
19 |
+
".venv",
|
20 |
+
".vscode",
|
21 |
+
"__pypackages__",
|
22 |
+
"_build",
|
23 |
+
"buck-out",
|
24 |
+
"build",
|
25 |
+
"dist",
|
26 |
+
"node_modules",
|
27 |
+
"site-packages",
|
28 |
+
"venv",
|
29 |
+
]
|
30 |
+
|
31 |
+
# Same as Black.
|
32 |
+
line-length = 88
|
33 |
+
indent-width = 4
|
34 |
+
|
35 |
+
# Assume Python 3.8
|
36 |
+
target-version = "py38"
|
37 |
+
|
38 |
+
[lint]
|
39 |
+
# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
|
40 |
+
# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
|
41 |
+
# McCabe complexity (`C901`) by default.
|
42 |
+
select = ["E4", "E7", "E9", "F"]
|
43 |
+
ignore = []
|
44 |
+
|
45 |
+
# Allow fix for all enabled rules (when `--fix`) is provided.
|
46 |
+
fixable = ["ALL"]
|
47 |
+
unfixable = []
|
48 |
+
|
49 |
+
# Allow unused variables when underscore-prefixed.
|
50 |
+
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
|
51 |
+
|
52 |
+
[format]
|
53 |
+
# Like Black, use double quotes for strings.
|
54 |
+
quote-style = "double"
|
55 |
+
|
56 |
+
# Like Black, indent with spaces, rather than tabs.
|
57 |
+
indent-style = "space"
|
58 |
+
|
59 |
+
# Like Black, respect magic trailing commas.
|
60 |
+
skip-magic-trailing-comma = false
|
61 |
+
|
62 |
+
# Like Black, automatically detect the appropriate line ending.
|
63 |
+
line-ending = "auto"
|
64 |
+
|
65 |
+
# Enable auto-formatting of code examples in docstrings. Markdown,
|
66 |
+
# reStructuredText code/literal blocks and doctests are all supported.
|
67 |
+
#
|
68 |
+
# This is currently disabled by default, but it is planned for this
|
69 |
+
# to be opt-out in the future.
|
70 |
+
docstring-code-format = false
|
71 |
+
|
72 |
+
# Set the line length limit used when formatting code snippets in
|
73 |
+
# docstrings.
|
74 |
+
#
|
75 |
+
# This only has an effect when the `docstring-code-format` setting is
|
76 |
+
# enabled.
|
77 |
+
docstring-code-line-length = "dynamic"
|