Rams901 committed on
Commit
e72ac74
1 Parent(s): 16b2cfc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -5
app.py CHANGED
@@ -23,11 +23,48 @@ from typing import Optional, List, Mapping, Any
23
  import ast
24
  from utils import ClaudeLLM
25
 
26
- embeddings = HuggingFaceEmbeddings()
27
- db_art = FAISS.load_local('db_art', embeddings)
28
- db_yt = FAISS.load_local('db_yt', embeddings)
 
 
 
29
  mp_docs = {}
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def retrieve_thoughts(query, n, db):
32
 
33
  # print(db.similarity_search_with_score(query = query, k = k, fetch_k = k*10))
@@ -66,7 +103,7 @@ def qa_retrieve_art(query,):
66
  global db_art
67
 
68
  global mp_docs
69
- thoughts = retrieve_thoughts(query, 0, db_art)
70
  if not(thoughts):
71
 
72
  if mp_docs:
@@ -116,6 +153,6 @@ ref_art = gr.Interface(fn=qa_retrieve_art, label="Articles",
116
  ref_yt = gr.Interface(fn=qa_retrieve_yt, label="Youtube",
117
  inputs=gr.inputs.Textbox(lines=5, label="what would you like to learn about?"),
118
  outputs = gr.components.JSON(label="youtube"),title = "youtube", examples=examples)
119
- demo = gr.Parallel( ref_art, ref_yt)
120
 
121
  demo.launch()
 
23
  import ast
24
  from utils import ClaudeLLM
25
 
26
+ from qdrant_client import models, QdrantClient
27
+ from sentence_transformers import SentenceTransformer
28
+
29
+ # embeddings = HuggingFaceEmbeddings()
30
+ # db_art = FAISS.load_local('db_art', embeddings)
31
+ # db_yt = FAISS.load_local('db_yt', embeddings)
32
  mp_docs = {}
33
 
34
# Client for the hosted Qdrant instance that stores the vector collections.
# The API key is read from the process environment; a missing
# 'Qdrant_Api_Key' variable raises KeyError at import time, which is
# preferable to starting the app with a client that cannot authenticate.
qdrant = QdrantClient(
    "https://0a1b865d-8291-41ef-8c29-ca6c35e26391.us-east4-0.gcp.cloud.qdrant.io:6333",
    prefer_grpc=True,
    # `os.env` does not exist; the environment mapping is `os.environ`.
    api_key=os.environ['Qdrant_Api_Key'],
)

# Sentence encoder used to turn query text into the vector sent to Qdrant.
encoder = SentenceTransformer('BAAI/bge-large-en-v1.5')
40
def q_retrieve_thoughts(query, n, db="articles", *, limit=4000, min_score=0.5):
    """Search a Qdrant collection for ``query`` and merge the hits per document.

    Each hit's payload is read for the keys ``_id`` (document id), ``id``
    (chunk order within the document), ``page_content``, ``title``, ``url``
    and ``author``. Chunks that share an ``_id`` are joined in ``id`` order
    with a ``"\\n...\\n"`` separator, and the document score is the mean of
    its chunk scores.

    Args:
        query: Text to embed with the module-level ``encoder`` and search for.
        n: Unused; kept so existing callers that pass it keep working.
        db: Name of the Qdrant collection to search.
        limit: Maximum number of chunk hits requested from Qdrant.
        min_score: Documents whose mean score is not strictly greater than
            this value are dropped.

    Returns:
        ``{'tier 1': DataFrame}`` with columns ``title``, ``url``,
        ``author``, ``content`` and ``score``, sorted by descending score.
        The frame is empty when the search returns no hits.
    """
    # Bind the search result (it was previously discarded, leaving `hits`
    # undefined) and search for the caller's query in the requested collection.
    hits = qdrant.search(
        collection_name=db,
        query_vector=encoder.encode(query).tolist(),
        limit=limit,  # TODO: derive from the collection size, e.g. via get_collection
    )

    columns = ['title', 'url', 'author', 'content', 'score']
    if not hits:
        # No hits means no payload columns to group on; return an empty frame
        # with the usual schema instead of failing on a missing column.
        return {'tier 1': pd.DataFrame(columns=columns)}

    # One row per chunk: the payload fields plus the similarity score.
    payload = pd.DataFrame.from_records(
        [{**(hit.payload or {}), 'score': hit.score} for hit in hits]
    )
    # Best-scoring chunk first, so `.first()` below takes the metadata of the
    # highest-scoring chunk of every document.
    payload = payload.sort_values('score', ascending=False)

    by_doc = payload.groupby('_id')
    merged = by_doc[['title', 'url', 'author']].first()
    # Re-assemble each document's text in chunk order.
    merged['content'] = (
        payload.sort_values('id')
        .groupby('_id')['page_content']
        .agg("\n...\n".join)
    )
    merged['score'] = by_doc['score'].mean()
    merged = merged.reset_index(drop=True)

    # Filter and sort without in-place mutation of a slice, which avoids
    # pandas chained-assignment warnings.
    merged = merged[merged['score'] > min_score]
    return {'tier 1': merged.sort_values('score', ascending=False)}
67
+
68
  def retrieve_thoughts(query, n, db):
69
 
70
  # print(db.similarity_search_with_score(query = query, k = k, fetch_k = k*10))
 
103
  global db_art
104
 
105
  global mp_docs
106
+ thoughts = q_retrieve_thoughts(query, 0)
107
  if not(thoughts):
108
 
109
  if mp_docs:
 
153
  ref_yt = gr.Interface(fn=qa_retrieve_yt, label="Youtube",
154
  inputs=gr.inputs.Textbox(lines=5, label="what would you like to learn about?"),
155
  outputs = gr.components.JSON(label="youtube"),title = "youtube", examples=examples)
156
+ demo = gr.Parallel( ref_art,)
157
 
158
  demo.launch()