# Scraped Hugging Face file-view header (not part of the program):
# uploaded by thomasht86 using huggingface_hub, commit 4775a9f (verified),
# file size 6.17 kB.
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from fasthtml.common import *
from shad4fast import *
from vespa.application import Vespa
import time
from backend.colpali import (
get_result_from_query,
get_query_embeddings_and_token_map,
add_sim_maps_to_result,
)
from backend.vespa_app import get_vespa_app
from backend.cache import LRUCache
from backend.modelmanager import ModelManager
from frontend.app import Home, Search, SearchBox, SearchResult
from frontend.layout import Layout
import hashlib
# highlight.js header elements that are later passed to fast_app(hdrs=...).
# The stylesheet link is created with an empty href; presumably the theme
# script below fills it in at runtime -- confirm in highlightjs-theme.js.
highlight_js_theme_link = Link(id="highlight-theme", rel="stylesheet", href="")
highlight_js_theme = Script(src="/static/js/highlightjs-theme.js")
# Languages to highlight plus the theme names for light and dark mode.
highlight_js = HighlightJS(
    langs=["python", "javascript", "java", "json", "xml"],
    light="github",
    dark="github-dark",
)
# Create the FastHTML application and its route decorator.
# pico=False is passed alongside the Shad4Fast head element and the
# highlight.js header elements defined above.
app, rt = fast_app(
    pico=False,
    htmlkw={"cls": "h-full"},
    hdrs=(
        ShadHead(tw_cdn=False, theme_handle=True),
        highlight_js,
        highlight_js_theme_link,
        highlight_js_theme,
    ),
)
# Module-level state shared by all request handlers.
vespa_app: Vespa = get_vespa_app()
# Search results keyed by query_id. Each result can be ~10MB, hence the
# small capacity.
result_cache = LRUCache(max_size=20)
# Map from query_id to a boolean value: False if not all results are ready.
task_cache = LRUCache(max_size=1000)
# Executor used to run the similarity-map generation off the event loop.
thread_pool = ThreadPoolExecutor()
def generate_query_id(query):
    """Return a stable hexadecimal identifier for ``query``.

    The identifier is the MD5 hex digest of the UTF-8 encoded string. It is
    only used as a cache key, never for security, so the hash is created with
    ``usedforsecurity=False``; this keeps the function usable on Python builds
    that restrict MD5 (e.g. FIPS mode) without changing the digest.

    Args:
        query: Text to derive the identifier from.

    Returns:
        A 32-character lowercase hexadecimal string.
    """
    digest = hashlib.md5(query.encode("utf-8"), usedforsecurity=False)
    return digest.hexdigest()
@rt("/static/{filepath:path}")
def serve_static(filepath: str):
    """Serve a file from the local ``./static`` directory.

    The ``path`` converter accepts arbitrary sub-paths, including ``..``
    segments, so the requested location is normalised and checked to be
    inside ``./static`` before anything is read from disk.

    Args:
        filepath: Path of the requested asset, relative to ``./static``.

    Returns:
        A ``FileResponse`` for an existing file inside ``./static``, or an
        empty-handed 404 response when the path escapes the directory or does
        not point at a regular file.
    """
    import os.path

    static_root = os.path.abspath("./static")
    # abspath collapses ".." lexically, so escaping paths become visible here.
    requested = os.path.abspath(f"./static/{filepath}")
    inside_root = os.path.commonpath([static_root, requested]) == static_root
    if not inside_root or not os.path.isfile(requested):
        return HTMLResponse("Not Found", status_code=404)
    return FileResponse(f"./static/{filepath}")
@rt("/")
def get():
    """Render the home page wrapped in the shared site layout."""
    home_page = Home()
    return Layout(home_page)
@rt("/search")
def get(request):
    """Render the search page for the ``query`` and ``ranking`` URL parameters.

    With a non-empty query the ``Search`` component is rendered for the
    request (search box plus loading message). Without one, the search box is
    shown together with a notice asking for a query.
    """
    params = request.query_params
    query_value = params.get("query", "").strip()
    ranking_value = params.get("ranking", "nn+colpali")
    print("/search: Fetching results for ranking_value:", ranking_value)

    if query_value:
        # A query was supplied: show the SearchBox and the loading message.
        return Layout(Search(request))

    # No query: still render the SearchBox, followed by a hint for the user.
    search_box = SearchBox(query_value=query_value, ranking_value=ranking_value)
    missing_query_notice = Div(
        P(
            "No query provided. Please enter a query.",
            cls="text-center text-muted-foreground",
        ),
        cls="p-10",
    )
    return Layout(Div(search_box, missing_query_notice, cls="grid"))
# Strong references to in-flight similarity-map tasks. The event loop only
# keeps weak references to tasks, so a task whose handle is dropped may be
# garbage-collected before it finishes.
_background_tasks = set()


@rt("/fetch_results")
async def get(request, query: str, nn: bool = True):
    """Return search results for ``query`` and start similarity-map generation.

    Requests without an ``hx-request`` header are redirected to ``/search``.
    Cached results are returned immediately; otherwise Vespa is queried, the
    similarity maps are generated in a background task, and the first results
    are returned together with the ``query_id`` used to fetch updates.

    Args:
        request: Incoming request; the ``ranking`` query parameter selects the
            ranking and defaults to ``"nn+colpali"`` (the same default the
            ``/search`` handler uses) when absent.
        query: The search query text.
        nn: Accepted for interface compatibility; not used by this handler.

    Returns:
        A ``RedirectResponse`` for non-HTMX requests, otherwise a
        ``SearchResult`` whose ``query_id`` is ``None`` once the background
        task for this query is marked complete.
    """
    if "hx-request" not in request.headers:
        return RedirectResponse("/search")

    # A missing parameter used to yield None and break the string
    # concatenation below with a TypeError.
    ranking_value = request.query_params.get("ranking", "nn+colpali")
    print(
        f"/fetch_results: Fetching results for query: {query}, ranking: {ranking_value}"
    )
    # Generate a unique query_id based on the query and ranking value
    query_id = generate_query_id(query + ranking_value)

    # Single lookup: the entry cannot be evicted between a check and a re-read.
    cached_result = result_cache.get(query_id)
    if cached_result:
        print(f"Results for query_id {query_id} already in cache")
        search_results = get_results_children(cached_result)
        # Hand back the query_id only while the background task is unfinished.
        pending_query_id = None if task_cache.get(query_id) else query_id
        return SearchResult(search_results, pending_query_id)

    task_cache.set(query_id, False)

    # Fetch model and processor
    manager = ModelManager.get_instance()
    model = manager.model
    processor = manager.processor
    # NOTE(review): this call runs synchronously inside the async handler;
    # consider offloading it if it turns out to block the event loop.
    q_embs, token_to_idx = get_query_embeddings_and_token_map(processor, model, query)

    start = time.perf_counter()
    # Fetch real search results from Vespa
    result = await get_result_from_query(
        app=vespa_app,
        processor=processor,
        model=model,
        query=query,
        q_embs=q_embs,
        token_to_idx=token_to_idx,
        ranking=ranking_value,
    )
    end = time.perf_counter()

    # A response without timing information must not fail the request just
    # because of a log line.
    try:
        searchtime = result["timing"]["searchtime"]
    except (KeyError, TypeError):
        searchtime = "unknown"
    print(
        f"Search results fetched in {end - start:.2f} seconds, Vespa says searchtime was {searchtime} seconds"
    )

    # Start generating the similarity map in the background, keeping a
    # reference until the task is done.
    sim_map_task = asyncio.create_task(
        generate_similarity_map(
            model, processor, query, q_embs, token_to_idx, result, query_id
        )
    )
    _background_tasks.add(sim_map_task)
    sim_map_task.add_done_callback(_background_tasks.discard)

    search_results = get_results_children(result)
    return SearchResult(search_results, query_id)
def get_results_children(result):
    """Return the list stored under ``result["root"]["children"]``.

    Args:
        result: Mapping holding a search response, or ``None``.

    Returns:
        The value of ``result["root"]["children"]`` when both keys are
        present; an empty list when ``result`` is empty/``None``, when
        ``"root"`` is missing or empty/``None``, or when ``"children"`` is
        missing.
    """
    if not result or "root" not in result:
        return []
    root = result["root"]
    # A null root would otherwise raise TypeError on the membership test.
    if not root or "children" not in root:
        return []
    return root["children"]
async def generate_similarity_map(
    model, processor, query, q_embs, token_to_idx, result, query_id
):
    """Run ``add_sim_maps_to_result`` in the thread pool and cache its output.

    On success the returned result is stored in ``result_cache`` under
    ``query_id`` and the task is marked complete in ``task_cache``.

    Args:
        model: Model forwarded to ``add_sim_maps_to_result``.
        processor: Processor forwarded to ``add_sim_maps_to_result``.
        query: The search query text.
        q_embs: Query embeddings forwarded to ``add_sim_maps_to_result``.
        token_to_idx: Token map forwarded to ``add_sim_maps_to_result``.
        result: Search result to which similarity maps are added.
        query_id: Cache key for this query.

    Raises:
        Exception: Whatever ``add_sim_maps_to_result`` raises is logged and
            re-raised after the task has been marked as finished.
    """
    # Inside a coroutine the running loop is the one to use.
    loop = asyncio.get_running_loop()
    sim_map_task = partial(
        add_sim_maps_to_result,
        result=result,
        model=model,
        processor=processor,
        query=query,
        q_embs=q_embs,
        token_to_idx=token_to_idx,
        query_id=query_id,
        result_cache=result_cache,
    )
    try:
        sim_map_result = await loop.run_in_executor(thread_pool, sim_map_task)
    except Exception as exc:
        print(f"Similarity map generation failed for query_id {query_id}: {exc!r}")
        # No further updates will arrive; flag the task as finished so
        # handlers stop handing out the query_id for this query.
        task_cache.set(query_id, True)
        raise
    result_cache.set(query_id, sim_map_result)
    task_cache.set(query_id, True)
@app.get("/updated_search_results")
async def updated_search_results(query_id: str):
    """Return the currently cached results for ``query_id``.

    Responds with an empty 204 when nothing is cached for the id. Otherwise
    the cached results are rendered; the ``query_id`` is included only while
    the background task is still unfinished and is replaced by ``None`` once
    it has completed, which stops the polling.
    """
    result = result_cache.get(query_id)
    if result is None:
        return HTMLResponse(status_code=204)

    search_results = get_results_children(result)
    task_finished = task_cache.get(query_id)
    next_query_id = None if task_finished else query_id
    return SearchResult(results=search_results, query_id=next_query_id)
@rt("/app")
def get():
    """Render a small page showing the URL of the configured Vespa app."""
    connection_info = P(f"Connected to Vespa at {vespa_app.url}")
    return Layout(Div(connection_info, cls="p-4"))
if __name__ == "__main__":
    # Start the server on port 7860 when the module is run as a script.
    # NOTE: eager model initialisation is currently disabled here:
    # ModelManager.get_instance()  # Initialize once at startup
    serve(port=7860)