thomasht86 commited on
Commit
94df778
·
verified ·
1 Parent(s): ece4c70

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. .env.example +2 -1
  2. backend/colpali.py +10 -1
  3. backend/vespa_app.py +3 -0
  4. main.py +5 -3
.env.example CHANGED
@@ -9,4 +9,5 @@ VESPA_CLOUD_MTLS_KEY="-----BEGIN PRIVATE KEY-----
9
  -----END PRIVATE KEY-----"
10
  VESPA_CLOUD_MTLS_CERT="-----BEGIN CERTIFICATE-----
11
  ...
12
- -----END CERTIFICATE-----"
 
 
9
  -----END PRIVATE KEY-----"
10
  VESPA_CLOUD_MTLS_CERT="-----BEGIN CERTIFICATE-----
11
  ...
12
+ -----END CERTIFICATE-----"
13
+ HOT_RELOAD=true
backend/colpali.py CHANGED
@@ -309,6 +309,8 @@ def add_sim_maps_to_result(
309
  vespa_sim_map = single_result["fields"].get("summaryfeatures", None)
310
  if vespa_sim_map:
311
  vespa_sim_maps.append(vespa_sim_map)
 
 
312
  sim_map_imgs_generator = gen_similarity_maps(
313
  model=model,
314
  processor=processor,
@@ -322,7 +324,14 @@ def add_sim_maps_to_result(
322
  )
323
  for img_idx, token, sim_mapb64 in sim_map_imgs_generator:
324
  print(f"Created sim map for image {img_idx} and token {token}")
325
- result["root"]["children"][img_idx]["fields"][f"sim_map_{token}"] = sim_mapb64
 
 
 
 
 
 
 
326
  # Update result_cache with the new sim_map
327
  result_cache.set(query_id, result)
328
  # for single_result, sim_map_dict in zip(result["root"]["children"], sim_map_imgs):
 
309
  vespa_sim_map = single_result["fields"].get("summaryfeatures", None)
310
  if vespa_sim_map:
311
  vespa_sim_maps.append(vespa_sim_map)
312
+ if not imgs:
313
+ return result
314
  sim_map_imgs_generator = gen_similarity_maps(
315
  model=model,
316
  processor=processor,
 
324
  )
325
  for img_idx, token, sim_mapb64 in sim_map_imgs_generator:
326
  print(f"Created sim map for image {img_idx} and token {token}")
327
+ if (
328
+ len(result["root"]["children"]) > img_idx
329
+ and "fields" in result["root"]["children"][img_idx]
330
+ and "sim_map" in result["root"]["children"][img_idx]["fields"]
331
+ ):
332
+ result["root"]["children"][img_idx]["fields"][f"sim_map_{token}"] = (
333
+ sim_mapb64
334
+ )
335
  # Update result_cache with the new sim_map
336
  result_cache.set(query_id, result)
337
  # for single_result, sim_map_dict in zip(result["root"]["children"], sim_map_imgs):
backend/vespa_app.py CHANGED
@@ -279,6 +279,9 @@ class VespaQueryClient:
279
  raise ValueError(f"Unsupported ranking: {ranking}")
280
 
281
  # Print score, title id, and text of the results
 
 
 
282
  for idx, child in enumerate(result["root"]["children"]):
283
  print(
284
  f"Result {idx+1}: {child['relevance']}, {child['fields']['title']}, {child['fields']['id']}"
 
279
  raise ValueError(f"Unsupported ranking: {ranking}")
280
 
281
  # Print score, title id, and text of the results
282
+ if "root" not in result or "children" not in result["root"]:
283
+ result["root"] = {"children": []}
284
+ return result
285
  for idx, child in enumerate(result["root"]["children"]):
286
  print(
287
  f"Result {idx+1}: {child['relevance']}, {child['fields']['title']}, {child['fields']['id']}"
main.py CHANGED
@@ -1,12 +1,12 @@
1
  import asyncio
2
  import base64
3
- import hashlib
4
  import io
5
  import os
6
  import time
7
  from concurrent.futures import ThreadPoolExecutor
8
  from functools import partial
9
  from pathlib import Path
 
10
 
11
  import google.generativeai as genai
12
  from fasthtml.common import *
@@ -112,7 +112,7 @@ async def keepalive():
112
 
113
 
114
  def generate_query_id(query):
115
- return hashlib.md5(query.encode("utf-8")).hexdigest()
116
 
117
 
118
  @rt("/static/{filepath:path}")
@@ -394,4 +394,6 @@ def get():
394
 
395
  if __name__ == "__main__":
396
  # ModelManager.get_instance() # Initialize once at startup
397
- serve(port=7860, reload=False)
 
 
 
1
  import asyncio
2
  import base64
 
3
  import io
4
  import os
5
  import time
6
  from concurrent.futures import ThreadPoolExecutor
7
  from functools import partial
8
  from pathlib import Path
9
+ import uuid
10
 
11
  import google.generativeai as genai
12
  from fasthtml.common import *
 
112
 
113
 
114
  def generate_query_id(query):
115
+ return uuid.uuid4().hex
116
 
117
 
118
  @rt("/static/{filepath:path}")
 
394
 
395
  if __name__ == "__main__":
396
  # ModelManager.get_instance() # Initialize once at startup
397
+ HOT_RELOAD = os.getenv("HOT_RELOAD", "False").lower() == "true"
398
+ print(f"Starting app with hot reload: {HOT_RELOAD}")
399
+ serve(port=7860, reload=HOT_RELOAD)