"""Zotero paper search: a FastHTML app over a LanceDB hybrid-search index of arXiv papers."""

import json
import os
from datetime import datetime
from typing import ClassVar

# import dotenv
import lancedb
import srsly
from fasthtml.common import *  # noqa
from fasthtml_hf import setup_hf_backup
from huggingface_hub import snapshot_download
from lancedb.embeddings.base import TextEmbeddingFunction
from lancedb.embeddings.registry import register
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import CohereReranker, ColbertReranker
from lancedb.util import attempt_import_or_raise

# dotenv.load_dotenv()

# download the zotero index (~1200 papers as of July 24, currently hosted on HF) ----

# Location of the local LanceDB database; shared by the existence check and the connection below.
DB_PATH = "./data/.lancedb_zotero_v0"


def download_data():
    """Download the ``rbiswasfc/zotero_db`` dataset snapshot from the HF Hub into ./data.

    HF_TOKEN is read with ``os.environ.get``. When it is unset, ``token=None``
    is passed and huggingface_hub falls back to its own credential lookup,
    instead of this function failing with a KeyError.
    """
    snapshot_download(
        repo_id="rbiswasfc/zotero_db",
        repo_type="dataset",
        local_dir="./data",
        token=os.environ.get("HF_TOKEN"),
    )
    print("Data downloaded!")


if not os.path.exists(DB_PATH):  # TODO: implement a better check / refresh mechanism
    download_data()


# cohere embedding utils ----


@register("coherev3")
class CohereEmbeddingFunction_2(TextEmbeddingFunction):
    """LanceDB text-embedding function backed by Cohere's ``embed`` endpoint.

    Registered with LanceDB under the name "coherev3". One Cohere client is
    shared by all instances (class attribute). It is created lazily on first
    use from the COHERE_API_KEY environment variable.
    """

    name: str = "embed-english-v3.0"
    client: ClassVar = None

    def ndims(self):
        """Return the embedding dimensionality.

        Must agree with the ``Vector(1024)`` column declared on ArxivModel.
        The previous value (768) did not.
        """
        return 1024

    def generate_embeddings(self, texts, input_type="search_document"):
        """
        Get the embeddings for the given texts

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed
        input_type: str, default "search_document"
            Cohere v3 input type. Indexed documents use "search_document";
            search queries should use "search_query" (see
            ``compute_query_embeddings``).

        Returns
        -------
        list
            One embedding per input text, in input order.
        """
        # TODO retry, rate limit, token limit
        self._init_client()
        rs = CohereEmbeddingFunction_2.client.embed(
            texts=texts, model=self.name, input_type=input_type
        )
        return list(rs.embeddings)

    def compute_query_embeddings(self, query, *args, **kwargs):
        """Embed a search query.

        The inherited implementation routes queries through the document path.
        That path would embed them with input_type="search_document", but
        Cohere v3 models expect queries to be embedded as "search_query".
        """
        return self.generate_embeddings(
            self.sanitize_input(query), input_type="search_query"
        )

    def _init_client(self):
        """Create the shared Cohere client on first use (no-op afterwards)."""
        cohere = attempt_import_or_raise("cohere")
        if CohereEmbeddingFunction_2.client is None:
            CohereEmbeddingFunction_2.client = cohere.Client(
                os.environ["COHERE_API_KEY"]
            )


COHERE_EMBEDDER = CohereEmbeddingFunction_2.create()

# LanceDB model ----


class ArxivModel(LanceModel):
    """Row schema for the arXiv chunk table."""

    # Source text; COHERE_EMBEDDER embeds this field into `vector`.
    text: str = COHERE_EMBEDDER.SourceField()
    vector: Vector(1024) = COHERE_EMBEDDER.VectorField()
    # NOTE(review): `title` vs `paper_title` semantics are not visible here;
    # presumably `title` is the chunk/section title — confirm.
    title: str
    paper_title: str
    content_type: str
    arxiv_id: str


VERSION = "0.0.0a"

DB = lancedb.connect(DB_PATH)
ID_TO_ABSTRACT = srsly.read_json("./data/id_to_abstract.json")
RERANKERS = {"colbert": ColbertReranker(), "cohere": CohereReranker()}
TBL = DB.open_table("arxiv_zotero_v0")

# format results ----


def _format_results(arxiv_refs):
    """Build display records for a set of papers.

    Parameters
    ----------
    arxiv_refs: dict[str, str]
        Mapping of arXiv id -> paper title, in the order results should appear.

    Returns
    -------
    list[dict]
        One dict per paper with "title", "url" (the arxiv.org abstract page)
        and "abstract" (looked up in ID_TO_ABSTRACT and heuristically cleaned).
    """
    results = []
    for arx_id, paper_title in arxiv_refs.items():
        # `or ""` also covers ids whose stored abstract is null.
        abstract = ID_TO_ABSTRACT.get(arx_id) or ""
        # these are all ugly hacks because the data preprocessing is poor. to be fixed v soon.
        # Keep only the text after the "Abstract" heading.
        if "Abstract\n\n" in abstract:
            abstract = abstract.split("Abstract\n\n")[-1]
        # Strip the title if it was captured with the abstract. The truthiness
        # guard matters: `"" in s` is always True and str.split("") raises ValueError.
        if paper_title and paper_title in abstract:
            abstract = abstract.split(paper_title)[-1]
        if abstract.startswith("\n"):
            abstract = abstract[1:]
        # A paragraph break within the first 20 chars presumably marks a
        # leftover header line; drop the first paragraph.
        if "\n\n" in abstract[:20]:
            abstract = "\n\n".join(abstract.split("\n\n")[1:])
        results.append(
            {
                "title": paper_title,
                "url": f"https://arxiv.org/abs/{arx_id}",
                "abstract": abstract,
            }
        )
    return results


# Search logic ----


def query_db(query: str, k: int = 10, reranker: str = "cohere", n_papers: int = 3):
    """Run a hybrid search over the chunk table and return the best-matching papers.

    Parameters
    ----------
    query: str
        The user's search text.
    k: int
        Number of chunks to retrieve from the hybrid search.
    reranker: str or None
        Key into RERANKERS ("colbert" or "cohere"), or None to skip reranking.
    n_papers: int
        Number of distinct papers to return (default 3).

    Returns
    -------
    list[dict]
        Output of ``_format_results``, ordered by each paper's summed chunk
        relevance score, highest first.
    """
    search = TBL.search(query, query_type="hybrid").limit(k)
    if reranker is not None:
        search = search.rerank(reranker=RERANKERS[reranker])
    chunks = search.to_pandas()

    # A paper can contribute several chunks; score each paper by the sum of its chunks' scores.
    paper_scores = (
        chunks.groupby("arxiv_id")["_relevance_score"]
        .sum()
        .sort_values(ascending=False)
        .head(n_papers)
    )
    # One title per paper (last row wins, as with the previous dict build).
    titles = chunks.drop_duplicates("arxiv_id", keep="last").set_index("arxiv_id")[
        "paper_title"
    ]
    # Iterate the *sorted* index so the output order follows the scores.
    # The previous version followed raw chunk order.
    top_papers = {arx_id: titles[arx_id] for arx_id in paper_scores.index}
    return _format_results(top_papers)


###########################################################################
# FastHTML app -----
###########################################################################

style = Style("""
:root { color-scheme: dark; }
body { max-width: 1200px; margin: 0 auto; padding: 20px; line-height: 1.6; }
#query { width: 100%; margin-bottom: 1rem; }
#search-form button { width: 100%; }
#search-results, #log-entries { margin-top: 2rem; }
.log-entry { border: 1px solid #ccc; padding: 10px; margin-bottom: 10px; }
.log-entry pre { white-space: pre-wrap; word-wrap: break-word; }
.htmx-indicator { display: none; }
.htmx-request .htmx-indicator { display: inline-block; }
.spinner {
    display: inline-block;
    width: 2.5em;
    height: 2.5em;
    border: 0.3em solid rgba(255,255,255,.3);
    border-radius: 50%;
    border-top-color: #fff;
    animation: spin 1s ease-in-out infinite;
    margin-left: 10px;
    vertical-align: middle;
}
@keyframes spin { to { transform: rotate(360deg); } }
.searching-text {
    font-size: 1.2em;
    font-weight: bold;
    color: #fff;
    margin-right: 10px;
    vertical-align: middle;
}
""")

# get the fast app and route
app, rt = fast_app(live=True, hdrs=(style,))

# Initialize a database to store search logs --
db = database("log_data/search_logs.db")
search_logs = db.t.search_logs
if search_logs not in db.t:
    search_logs.create(
        id=int,
        timestamp=str,
        query=str,
        results=str,
        pk="id",
    )
SearchLog = search_logs.dataclass()


def insert_log_entry(log_entry):
    "Insert a log entry into the database"
    # The timestamp is stored as ISO-8601 text and results as a JSON string, matching the str columns above.
    return search_logs.insert(
        SearchLog(
            timestamp=log_entry["timestamp"].isoformat(),
            query=log_entry["query"],
            results=json.dumps(log_entry["results"]),
        )
    )


@rt("/")
async def get():
    """Render the search page: query form, an empty results container, and a link to the logs."""
    query_form = Form(
        Textarea(id="query", name="query", placeholder="Enter your query..."),
        Button("Submit", type="submit"),
        Div(
            Span("Searching...", cls="searching-text htmx-indicator"),
            Span(cls="spinner htmx-indicator"),
            cls="indicator-container",
        ),
        id="search-form",
        hx_post="/search",
        hx_target="#search-results",
        hx_indicator=".indicator-container",
    )
    # results_div = Div(H2("Search Results"), Div(id="search-results", cls="results-container"))
    results_div = Div(Div(id="search-results", cls="results-container"))
    view_logs_link = A("View Logs", href="/logs", cls="view-logs-link")
    return Titled(
        "Zotero Search", Div(query_form, results_div, view_logs_link, cls="container")
    )


def SearchResult(result):
    "Custom component for displaying a search result"
    return Card(
        H4(A(result["title"], href=result["url"], target="_blank")),
        P(result["abstract"]),
        footer=A("Read more →", href=result["url"], target="_blank"),
    )


def log_query_and_results(query, results):
    """Persist a query together with the title/url of each returned result."""
    log_entry = {
        "timestamp": datetime.now(),
        "query": query,
        "results": [{"title": r["title"], "url": r["url"]} for r in results],
    }
    insert_log_entry(log_entry)


@rt("/search")
async def post(query: str):
    """Search for *query*, log it, and return result cards for HTMX to swap in.

    The form targets ``#search-results`` with the default innerHTML swap. The
    returned wrapper therefore must not carry ``id="search-results"`` itself,
    or each search would nest a duplicate id inside the target.
    """
    # Don't send blank queries to the embedder/reranker, and don't log them.
    if not query.strip():
        return Div(P("Please enter a query."))
    # NOTE(review): query_db is synchronous (network + disk) and is called
    # directly in this async handler. That blocks the event loop for the
    # duration of the search; consider running it in a worker thread.
    results = query_db(query)
    log_query_and_results(query, results)
    return Div(*[SearchResult(r) for r in results])


def LogEntry(entry):
    """Render one stored search-log row (query, timestamp, raw JSON results)."""
    return Div(
        H4(f"Query: {entry.query}"),
        P(f"Timestamp: {entry.timestamp}"),
        H5("Results:"),
        Pre(entry.results),
        cls="log-entry",
    )


@rt("/logs")
async def get():
    """Render the 50 most recent search logs, newest first."""
    logs = search_logs(order_by="-id", limit=50)  # Get the latest 50 logs
    log_entries = [LogEntry(log) for log in logs]
    return Titled(
        "Logs",
        Div(
            H2("Recent Search Logs"),
            Div(*log_entries, id="log-entries"),
            A("Back to Search", href="/", cls="back-link"),
            cls="container",
        ),
    )


if __name__ == "__main__":
    import uvicorn

    setup_hf_backup(app)
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))

# run_uv()