Spaces:
Sleeping
Sleeping
File size: 4,961 Bytes
2e98c79 3bbb5f4 2e98c79 10f043b 2e98c79 4295f9e 2e98c79 38d2199 2e98c79 ff486ce 2e98c79 b41c2a6 2e98c79 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import os
import requests
import gradio as gr
import chromadb
import json
import pandas as pd
from sentence_transformers import SentenceTransformer
import spaces
# Instruction used when the caller does not supply a usable task string.
_DEFAULT_TASK = "Given a question, retrieve passages that answer the question"

# Filled in by _load_embedding_model() on first use so the weights are not
# re-instantiated for every single query.
_embedding_model = None


def _load_embedding_model():
    """Return the process-wide SentenceTransformer, creating it on first use."""
    global _embedding_model
    if _embedding_model is None:
        _embedding_model = SentenceTransformer(
            "Linq-AI-Research/Linq-Embed-Mistral",
            use_auth_token=os.getenv("HF_TOKEN"),
        )
    return _embedding_model


@spaces.GPU
def get_embeddings(text, task):
    """Embed a query string with the Linq-Embed-Mistral model.

    Args:
        text: The user's query text.
        task: Instruction describing the retrieval task. When it is ``None``,
            blank, or not a string, the default question-answering
            instruction is used instead.

    Returns:
        The encoded prompt as a NumPy array (the tensor returned by
        ``model.encode`` moved to the CPU). A single prompt is encoded, so
        presumably the array has one row -- confirm against the model.
    """
    # The original code overwrote ``task`` unconditionally, silently ignoring
    # whatever the caller passed; honour it when it is a real instruction.
    if isinstance(task, str) and task.strip():
        instruction = task
    else:
        instruction = _DEFAULT_TASK

    prompt = f"Instruct: {instruction}\nQuery: {text}"

    # NOTE(review): if the GPU decorator runs this function in a separate
    # worker process, the cache only helps within that process -- confirm.
    model = _load_embedding_model()

    # encode() expects a batch, hence the single-element list.
    query_embeddings = model.encode([prompt], convert_to_tensor=True)
    return query_embeddings.cpu().numpy()
# Open the on-disk Chroma database and fetch the collection that is queried
# by the search functions below.
client = chromadb.PersistentClient(path="./chroma")
collection = client.get_collection(name="phil_de")

# Authors offered in the UI's author filter.
authors_list = [
    "Ludwig Wittgenstein",
    "Sigmund Freud",
    "Marcus Aurelius",
    "Friedrich Nietzsche",
    "Epiktet",
    "Ernst Jünger",
    "Georg Christoph Lichtenberg",
    "Balthasar Gracian",
    "Hannah Arendt",
    "Erich Fromm",
    "Albert Camus",
]
# Alternative shorter list kept for reference:
# authors_list = ["Friedrich Nietzsche", "Joscha Bach"]
def query_chroma(embeddings, authors, num_results=10):
    """Run a nearest-neighbour search against the module-level collection.

    Args:
        embeddings: Query embedding. May be a single vector or a batch of
            vectors (for example a 2-D array holding one row), given either
            as an object with ``tolist()`` or as a plain sequence. Only the
            results for the first query vector are returned.
        authors: Author name(s) to restrict the search to. An empty value
            or ``None`` disables the filter.
        num_results: Maximum number of hits to return.

    Returns:
        A list of dicts with the keys ``id``, ``author``, ``book``,
        ``section``, ``title``, ``text`` and ``distance``. If anything
        fails, a string starting with ``"Failed to query the database: "``
        is returned instead.
    """
    try:
        # Normalise to a list of vectors. Wrapping an already 2-D array in
        # another list (as the original did) adds a spurious dimension.
        if hasattr(embeddings, "tolist"):
            vectors = embeddings.tolist()
        else:
            vectors = list(embeddings)
        if vectors and not isinstance(vectors[0], (list, tuple)):
            vectors = [vectors]

        # A lone string would otherwise be treated as a sequence of characters.
        if isinstance(authors, str):
            authors = [authors]
        # Use None rather than {} for "no filter": an empty where-dict does
        # not pass Chroma's filter validation in newer releases.
        where_filter = {"author": {"$in": list(authors)}} if authors else None

        results = collection.query(
            query_embeddings=vectors,
            n_results=int(num_results),  # UI number inputs may deliver floats
            where=where_filter,
            include=["documents", "metadatas", "distances"],
        )

        # Each entry holds one list per query vector; take the first query's
        # hits and tolerate keys that are missing or set to None.
        ids = (results.get('ids') or [[]])[0]
        metadatas = (results.get('metadatas') or [[]])[0]
        documents = (results.get('documents') or [[]])[0]
        distances = (results.get('distances') or [[]])[0]

        formatted_results = []
        for id_, metadata, document_text, distance in zip(ids, metadatas, documents, distances):
            metadata = metadata or {}  # entries stored without metadata
            formatted_results.append({
                "id": id_,
                "author": metadata.get('author', 'Unknown author'),
                "book": metadata.get('book', 'Unknown book'),
                "section": metadata.get('section', 'Unknown section'),
                "title": metadata.get('title', 'Untitled'),
                "text": document_text,
                "distance": distance,
            })
        return formatted_results
    except Exception as e:
        # Callers receive the failure as a message string, not an exception.
        return f"Failed to query the database: {str(e)}"
# Main function
def perform_query(query, task, author, num_results):
    """Embed the query, search the collection and build the UI updates.

    Args:
        query: The user's query text.
        task: Retrieval instruction forwarded to ``get_embeddings``.
        author: Author filter forwarded to ``query_chroma``.
        num_results: Number of hits to request.

    Returns:
        A list of exactly ``3 * max_textboxes`` ``gr.update`` objects, three
        per result slot in the order (markdown text, flag button, id
        textbox). Slots beyond the number of hits are hidden.
    """
    total_outputs = 3 * max_textboxes

    embeddings = get_embeddings(query, task)
    initial_results = query_chroma(embeddings, author, num_results)

    updates = []
    if isinstance(initial_results, str):
        # query_chroma reports failures as a message string; show it in the
        # first slot instead of crashing while indexing into the string.
        updates.append(gr.update(visible=True, value=f"**Error**\n\n{initial_results}"))
        updates.append(gr.update(visible=False))
        updates.append(gr.update(visible=False))
    else:
        # Never emit more updates than there are output components.
        for index, res in enumerate(initial_results[:max_textboxes]):
            meta = f"{res['author']}, {res['book']}, Distance: {res['distance']}"
            markdown_content = f"**{meta}**\n\n{res['text']}"
            updates.append(gr.update(visible=True, value=markdown_content))
            updates.append(gr.update(visible=True, value="Flag", elem_id=f"flag-{index}"))
            # The id travels in a hidden textbox so flagging can record it.
            updates.append(gr.update(visible=False, value=res['id']))

    # Hide every remaining slot. The original padding expression,
    # 3 * (max_textboxes - len(results) // 3), applied the floor division
    # before the subtraction and so returned the wrong number of updates.
    updates.extend(gr.update(visible=False) for _ in range(total_outputs - len(updates)))
    return updates
# CSV-backed logger that stores the data points users flag.
callback = gr.CSVLogger()


def flag_output(query, output_text, output_id):
    """Log one flagged search result through the module-level CSV logger.

    Args:
        query: The query the result was returned for.
        output_text: The displayed text of the flagged result.
        output_id: The identifier of the flagged result.
    """
    flagged_row = [query, output_text, output_id]
    callback.flag(flagged_row)
# Gradio interface
# Number of pre-built (initially hidden) result slots; also the upper bound
# of the "number of results" input.
max_textboxes = 30


def _run_search(query, author, num_results):
    """Adapt the three UI inputs to perform_query's four-parameter signature."""
    try:
        wanted = int(num_results)
    except (TypeError, ValueError):
        # Cleared or non-numeric input: fall back to the field's default.
        wanted = 10
    wanted = max(1, min(wanted, max_textboxes))
    # The UI has no task field, so no instruction is passed.
    return perform_query(query, None, author, wanted)


with gr.Blocks(css=".custom-markdown { border: 1px solid #ccc; padding: 10px; border-radius: 5px; }") as demo:
    gr.Markdown("Enter your query, filter authors (default is all), click **Search** to search. Click **Flag** if a result is relevant to the query and interesting to you. Try reranking the results.")
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(label="query", placeholder="Enter thought...")
            author_inp = gr.Dropdown(label="authors", choices=authors_list, multiselect=True)
            # precision=0 makes the component deliver an int rather than a float.
            num_results_inp = gr.Number(label="number of results", value=10, step=1, minimum=1, maximum=max_textboxes, precision=0)
            btn = gr.Button("Search")

    # Flat list of all output components, three per slot, in the order
    # perform_query emits its updates: markdown, flag button, id textbox.
    components = []
    textboxes = []
    flag_buttons = []
    ids = []
    for _ in range(max_textboxes):
        with gr.Column():
            text_out = gr.Markdown(visible=False, elem_classes="custom-markdown")
            flag_btn = gr.Button(value="Flag", visible=False)
            id_out = gr.Textbox(visible=False)
        components.extend([text_out, flag_btn, id_out])
        textboxes.append(text_out)
        flag_buttons.append(flag_btn)
        ids.append(id_out)

    # flag_output logs exactly three values (query, text, id), so register
    # one component per value. Registering every slot would give the CSV
    # header far more columns than each logged row has.
    callback.setup([inp, textboxes[0], ids[0]], "flagged_data_points")

    # perform_query takes four arguments but only three inputs exist; the
    # adapter supplies the missing task argument.
    btn.click(
        fn=_run_search,
        inputs=[inp, author_inp, num_results_inp],
        outputs=components
    )

    # Each flag button logs the query together with its own slot's text and id.
    for slot_button, slot_text, slot_id in zip(flag_buttons, textboxes, ids):
        slot_button.click(
            fn=flag_output,
            inputs=[inp, slot_text, slot_id],
            outputs=[],
            preprocess=False
        )

demo.launch()
|