rbiswasfc commited on
Commit
7324658
1 Parent(s): 1f1f070
Files changed (3) hide show
  1. .gitignore +2 -0
  2. app.py +141 -0
  3. requirements.txt +9 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .env
2
+ data
app.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import ClassVar
3
+
4
+ # import dotenv
5
+ import gradio as gr
6
+ import lancedb
7
+ import srsly
8
+ from huggingface_hub import snapshot_download
9
+ from lancedb.embeddings.base import TextEmbeddingFunction
10
+ from lancedb.embeddings.registry import register
11
+ from lancedb.pydantic import LanceModel, Vector
12
+ from lancedb.rerankers import CohereReranker, ColbertReranker
13
+ from lancedb.util import attempt_import_or_raise
14
+
15
+ # dotenv.load_dotenv()
16
+
17
+
18
+ @register("coherev3")
19
+ class CohereEmbeddingFunction_2(TextEmbeddingFunction):
20
+ name: str = "embed-english-v3.0"
21
+ client: ClassVar = None
22
+
23
+ def ndims(self):
24
+ return 768
25
+
26
+ def generate_embeddings(self, texts):
27
+ """
28
+ Get the embeddings for the given texts
29
+
30
+ Parameters
31
+ ----------
32
+ texts: list[str] or np.ndarray (of str)
33
+ The texts to embed
34
+ """
35
+ # TODO retry, rate limit, token limit
36
+ self._init_client()
37
+ rs = CohereEmbeddingFunction_2.client.embed(
38
+ texts=texts, model=self.name, input_type="search_document"
39
+ )
40
+
41
+ return [emb for emb in rs.embeddings]
42
+
43
+ def _init_client(self):
44
+ cohere = attempt_import_or_raise("cohere")
45
+ if CohereEmbeddingFunction_2.client is None:
46
+ CohereEmbeddingFunction_2.client = cohere.Client(
47
+ os.environ["COHERE_API_KEY"]
48
+ )
49
+
50
+
51
+ COHERE_EMBEDDER = CohereEmbeddingFunction_2.create()
52
+
53
+
54
+ class ArxivModel(LanceModel):
55
+ text: str = COHERE_EMBEDDER.SourceField()
56
+ vector: Vector(1024) = COHERE_EMBEDDER.VectorField()
57
+ title: str
58
+ paper_title: str
59
+ content_type: str
60
+ arxiv_id: str
61
+
62
+
63
+ def download_data():
64
+ snapshot_download(
65
+ repo_id="rbiswasfc/zotero_db",
66
+ repo_type="dataset",
67
+ local_dir="./data",
68
+ token=os.environ["HF_TOKEN"],
69
+ )
70
+ print("Data downloaded!")
71
+
72
+
73
+ download_data()
74
+
75
+ VERSION = "0.0.0a"
76
+ DB = lancedb.connect("./data/.lancedb_zotero_v0")
77
+ ID_TO_ABSTRACT = srsly.read_json("./data/id_to_abstract.json")
78
+ RERANKERS = {"colbert": ColbertReranker(), "cohere": CohereReranker()}
79
+ TBL = DB.open_table("arxiv_zotero_v0")
80
+
81
+
82
+ def _format_results(arxiv_refs):
83
+ results = []
84
+ for arx_id, paper_title in arxiv_refs.items():
85
+ abstract = ID_TO_ABSTRACT.get(arx_id, "")
86
+ # these are all ugly hacks because the data preprocessing is poor. to be fixed v soon.
87
+ if "Abstract\n\n" in abstract:
88
+ abstract = abstract.split("Abstract\n\n")[-1]
89
+ if paper_title in abstract:
90
+ abstract = abstract.split(paper_title)[-1]
91
+ if abstract.startswith("\n"):
92
+ abstract = abstract[1:]
93
+ if "\n\n" in abstract[:20]:
94
+ abstract = "\n\n".join(abstract.split("\n\n")[1:])
95
+ result = {
96
+ "title": paper_title,
97
+ "url": f"https://arxiv.org/abs/{arx_id}",
98
+ "abstract": abstract,
99
+ }
100
+ results.append(result)
101
+
102
+ return results
103
+
104
+
105
+ def query_db(query: str, k: int = 10, reranker: str = "cohere"):
106
+ raw_results = TBL.search(query, query_type="hybrid").limit(k)
107
+ if reranker is not None:
108
+ ranked_results = raw_results.rerank(reranker=RERANKERS[reranker])
109
+ else:
110
+ ranked_results = raw_results
111
+
112
+ ranked_results = ranked_results.to_pandas()
113
+ top_results = ranked_results.groupby("arxiv_id").agg({"_relevance_score": "sum"})
114
+ top_results = top_results.sort_values(by="_relevance_score", ascending=False).head(
115
+ 3
116
+ )
117
+ top_results_dict = {
118
+ row["arxiv_id"]: row["paper_title"]
119
+ for index, row in ranked_results.iterrows()
120
+ if row["arxiv_id"] in top_results.index
121
+ }
122
+
123
+ final_results = _format_results(top_results_dict)
124
+ return final_results
125
+
126
+
127
+ with gr.Blocks() as demo:
128
+ with gr.Row():
129
+ query = gr.Textbox(label="Query", placeholder="Enter your query...")
130
+ submit_btn = gr.Button("Submit")
131
+ output = gr.JSON(label="Search Results")
132
+
133
+ # # callback ---
134
+ submit_btn.click(
135
+ fn=query_db,
136
+ inputs=query,
137
+ outputs=output,
138
+ )
139
+
140
+
141
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ lancedb
4
+ srsly
5
+ cohere
6
+ python-dotenv
7
+ tantivy
8
+ beautifulsoup4
9
+ retry