import os
from typing import ClassVar

# import dotenv
import gradio as gr
import lancedb
import srsly
from huggingface_hub import snapshot_download
from lancedb.embeddings.base import TextEmbeddingFunction
from lancedb.embeddings.registry import register
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import CohereReranker, ColbertReranker
from lancedb.util import attempt_import_or_raise

# dotenv.load_dotenv()


@register("coherev3")
class CohereEmbeddingFunction_2(TextEmbeddingFunction):
    name: str = "embed-english-v3.0"
    client: ClassVar = None

    def ndims(self):
        # embed-english-v3.0 returns 1024-dimensional vectors,
        # matching Vector(1024) in ArxivModel below
        return 1024

    def generate_embeddings(self, texts):
        """
        Get the embeddings for the given texts

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed
        """
        # TODO: retry, rate limit, token limit
        self._init_client()
        rs = CohereEmbeddingFunction_2.client.embed(
            texts=texts, model=self.name, input_type="search_document"
        )
        return [emb for emb in rs.embeddings]

    def _init_client(self):
        # lazily create a single shared Cohere client
        cohere = attempt_import_or_raise("cohere")
        if CohereEmbeddingFunction_2.client is None:
            CohereEmbeddingFunction_2.client = cohere.Client(
                os.environ["COHERE_API_KEY"]
            )


COHERE_EMBEDDER = CohereEmbeddingFunction_2.create()


class ArxivModel(LanceModel):
    text: str = COHERE_EMBEDDER.SourceField()
    vector: Vector(1024) = COHERE_EMBEDDER.VectorField()
    title: str
    paper_title: str
    content_type: str
    arxiv_id: str


def download_data():
    snapshot_download(
        repo_id="rbiswasfc/zotero_db",
        repo_type="dataset",
        local_dir="./data",
        token=os.environ["HF_TOKEN"],
    )
    print("Data downloaded!")


download_data()

VERSION = "0.0.0a"
DB = lancedb.connect("./data/.lancedb_zotero_v0")
ID_TO_ABSTRACT = srsly.read_json("./data/id_to_abstract.json")
RERANKERS = {"colbert": ColbertReranker(), "cohere": CohereReranker()}
TBL = DB.open_table("arxiv_zotero_v0")


def _format_results(arxiv_refs):
    results = []
    for arx_id, paper_title in arxiv_refs.items():
        abstract = ID_TO_ABSTRACT.get(arx_id, "")
        # ugly hacks because the data preprocessing is poor; to be fixed very soon
        if "Abstract\n\n" in abstract:
            abstract = abstract.split("Abstract\n\n")[-1]
        if paper_title in abstract:
            abstract = abstract.split(paper_title)[-1]
        if abstract.startswith("\n"):
            abstract = abstract[1:]
        if "\n\n" in abstract[:20]:
            abstract = "\n\n".join(abstract.split("\n\n")[1:])

        result = {
            "title": paper_title,
            "url": f"https://arxiv.org/abs/{arx_id}",
            "abstract": abstract,
        }
        results.append(result)
    return results


def query_db(query: str, k: int = 10, reranker: str = "cohere"):
    # hybrid (vector + full-text) search over the arxiv table
    raw_results = TBL.search(query, query_type="hybrid").limit(k)
    if reranker is not None:
        ranked_results = raw_results.rerank(reranker=RERANKERS[reranker])
    else:
        ranked_results = raw_results
    ranked_results = ranked_results.to_pandas()

    # aggregate relevance per paper and keep the top 3 arxiv ids
    top_results = ranked_results.groupby("arxiv_id").agg({"_relevance_score": "sum"})
    top_results = top_results.sort_values(by="_relevance_score", ascending=False).head(3)

    top_results_dict = {
        row["arxiv_id"]: row["paper_title"]
        for _, row in ranked_results.iterrows()
        if row["arxiv_id"] in top_results.index
    }

    final_results = _format_results(top_results_dict)
    return final_results


with gr.Blocks() as demo:
    with gr.Row():
        query = gr.Textbox(label="Query", placeholder="Enter your query...")
        submit_btn = gr.Button("Submit")
    output = gr.JSON(label="Search Results")

    # callback ---
    submit_btn.click(
        fn=query_db,
        inputs=query,
        outputs=output,
    )

demo.launch()
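
# ---------------------------------------------------------------------------
# Note: ArxivModel is declared above but never used here; the table ships with
# the downloaded snapshot, so it is only opened, never created. A minimal
# sketch of how a compatible table could be (re)built from a list of dicts
# (`records`, hypothetical, carrying the fields declared on ArxivModel):
#
#   tbl = DB.create_table("arxiv_zotero_v0", schema=ArxivModel, exist_ok=True)
#   tbl.add(records)              # embeddings computed via the registered Cohere function
#   tbl.create_fts_index("text")  # full-text index needed for hybrid search
# ---------------------------------------------------------------------------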