HonestAnnie committed on
Commit
2ec7158
1 Parent(s): 38d2199
Files changed (1) hide show
  1. app.py +16 -18
app.py CHANGED
@@ -1,21 +1,18 @@
1
  import os
2
- import requests
3
  import gradio as gr
4
  import chromadb
5
- import json
6
- import pandas as pd
7
 
8
  from sentence_transformers import SentenceTransformer
9
 
10
  import spaces
11
 
12
  @spaces.GPU
13
- def get_embeddings(text, task):
14
  model = SentenceTransformer("Linq-AI-Research/Linq-Embed-Mistral", use_auth_token=os.getenv("HF_TOKEN"))
15
  task = "Given a question, retrieve passages that answer the question"
16
- prompt = f"Instruct: {task}\nQuery: {text}" # Use text here
17
- query_embeddings = model.encode([prompt], convert_to_tensor=True) # Ensure it's a list
18
- return query_embeddings.cpu().numpy()
19
 
20
 
21
  # Initialize a persistent Chroma client and retrieve collection
@@ -56,27 +53,28 @@ def query_chroma(embeddings, authors, num_results=10):
56
 
57
  return formatted_results
58
  except Exception as e:
59
- return f"Failed to query the database: {str(e)}"
60
 
61
 
62
  # Main function
63
- def perform_query(query, task, author, num_results):
 
64
  embeddings = get_embeddings(query, task)
65
- initial_results = query_chroma(embeddings, author, num_results)
 
 
 
66
 
67
- results = [(f"{res['author']}, {res['book']}, Distance: {res['distance']}", res['text'], res['id']) for res in initial_results]
68
-
69
  updates = []
70
- for meta, text, id_ in results:
71
- markdown_content = f"**{meta}**\n\n{text}"
72
  updates.append(gr.update(visible=True, value=markdown_content))
73
  updates.append(gr.update(visible=True, value="Flag", elem_id=f"flag-{len(updates)//2}"))
74
- updates.append(gr.update(visible=False, value=id_)) # Hide the ID textbox
75
 
76
- updates += [gr.update(visible=False)] * (3 * (max_textboxes - len(results) // 3))
77
-
78
- return updates
79
 
 
80
 
81
  # Initialize the CSVLogger callback for flagging
82
  callback = gr.CSVLogger()
 
1
  import os
 
2
  import gradio as gr
3
  import chromadb
 
 
4
 
5
  from sentence_transformers import SentenceTransformer
6
 
7
  import spaces
8
 
9
  @spaces.GPU
10
+ def get_embeddings(query, task):
11
  model = SentenceTransformer("Linq-AI-Research/Linq-Embed-Mistral", use_auth_token=os.getenv("HF_TOKEN"))
12
  task = "Given a question, retrieve passages that answer the question"
13
+ prompt = f"Instruct: {task}\nQuery: {query}"
14
+ query_embeddings = model.encode([prompt])
15
+ return query_embeddings
16
 
17
 
18
  # Initialize a persistent Chroma client and retrieve collection
 
53
 
54
  return formatted_results
55
  except Exception as e:
56
+ return {"error": str(e)}
57
 
58
 
59
  # Main function
60
+ def perform_query(query, authors, num_results):
61
+ task = "Given a question, retrieve passages that answer the question"
62
  embeddings = get_embeddings(query, task)
63
+ results = query_chroma(embeddings, authors, num_results)
64
+
65
+ if "error" in results:
66
+ return [gr.update(visible=True, value=f"Error: {results['error']}") for _ in range(max_textboxes * 3)]
67
 
 
 
68
  updates = []
69
+ for res in results:
70
+ markdown_content = f"**{res['author']}, {res['book']}, Distance: {res['distance']}**\n\n{res['text']}"
71
  updates.append(gr.update(visible=True, value=markdown_content))
72
  updates.append(gr.update(visible=True, value="Flag", elem_id=f"flag-{len(updates)//2}"))
73
+ updates.append(gr.update(visible=False, value=res['id'])) # Hide the ID textbox
74
 
75
+ updates += [gr.update(visible=False)] * (3 * (max_textboxes - len(results)))
 
 
76
 
77
+ return updates
78
 
79
  # Initialize the CSVLogger callback for flagging
80
  callback = gr.CSVLogger()