svghenfpkob / app.py
HonestAnnie's picture
trhhh
2ec7158
raw
history blame
4.93 kB
import os
import gradio as gr
import chromadb
from sentence_transformers import SentenceTransformer
import spaces
# Instruction used in the embedding prompt when the caller does not supply one.
DEFAULT_RETRIEVAL_TASK = "Given a question, retrieve passages that answer the question"


@spaces.GPU
def get_embeddings(query, task=DEFAULT_RETRIEVAL_TASK):
    """Embed *query* with Linq-Embed-Mistral using an instruction-style prompt.

    Args:
        query: The user's search text.
        task: Instruction placed on the ``Instruct:`` line of the prompt.
            Falsy values (``None``, ``""``) fall back to
            ``DEFAULT_RETRIEVAL_TASK``.

    Returns:
        The result of ``model.encode`` for the one-element batch ``[prompt]``
        (by default a NumPy array with a single row).
    """
    # NOTE(review): the model is constructed on every call, which is slow.
    # ZeroGPU Spaces usually load it once at module level; left here to keep
    # import-time behaviour unchanged — confirm before moving it.
    model = SentenceTransformer("Linq-AI-Research/Linq-Embed-Mistral", use_auth_token=os.getenv("HF_TOKEN"))
    prompt = f"Instruct: {task or DEFAULT_RETRIEVAL_TASK}\nQuery: {query}"
    query_embeddings = model.encode([prompt])
    return query_embeddings
# Open the on-disk Chroma store at ./chroma and fetch the "phil_de"
# collection that every search in this app runs against.
client = chromadb.PersistentClient(
    path="./chroma",
)
collection = client.get_collection(
    name="phil_de",
)
# Authors offered in the filter dropdown; selected names are matched verbatim
# against each document's "author" metadata when a query is filtered.
authors_list = [
    "Ludwig Wittgenstein",
    "Sigmund Freud",
    "Marcus Aurelius",
    "Friedrich Nietzsche",
    "Epiktet",
    "Ernst Jünger",
    "Georg Christoph Lichtenberg",
    "Balthasar Gracian",
    "Hannah Arendt",
    "Erich Fromm",
    "Albert Camus",
]
def query_chroma(embeddings, authors, num_results=10):
    """Query the Chroma collection and flatten the hits into plain dicts.

    Args:
        embeddings: Query embedding produced by ``get_embeddings``; it is
            wrapped in a one-element list and passed as ``query_embeddings``.
        authors: Author names to restrict results to, matched against each
            document's ``author`` metadata. Falsy (``None``/empty) disables
            the filter.
        num_results: Maximum number of hits; coerced to ``int`` because UI
            number inputs may deliver floats such as ``10.0``.

    Returns:
        A list of dicts with keys ``id``, ``author``, ``book``, ``section``,
        ``title``, ``text`` and ``distance`` (one per hit, in the order Chroma
        returned them), or ``{"error": <message>}`` if anything raised.
    """
    try:
        # None is the default "no filter" value of Collection.query; an empty
        # dict only works where the client happens to treat it as falsy.
        where_filter = {"author": {"$in": authors}} if authors else None
        results = collection.query(
            query_embeddings=[embeddings],
            n_results=int(num_results),
            where=where_filter,
            include=["documents", "metadatas", "distances"],
        )

        def first_batch(key):
            # Results are grouped per query embedding and one query is sent,
            # so take group 0; a missing/None entry becomes an empty group.
            return (results.get(key) or [[]])[0]

        formatted_results = []
        for id_, metadata, document_text, distance in zip(
            first_batch("ids"),
            first_batch("metadatas"),
            first_batch("documents"),
            first_batch("distances"),
        ):
            # Chroma may return None for records stored without metadata.
            metadata = metadata or {}
            formatted_results.append({
                "id": id_,
                "author": metadata.get('author', 'Unknown author'),
                "book": metadata.get('book', 'Unknown book'),
                "section": metadata.get('section', 'Unknown section'),
                "title": metadata.get('title', 'Untitled'),
                "text": document_text,
                "distance": distance,
            })
        return formatted_results
    except Exception as e:
        # UI boundary: report the failure to the caller instead of crashing
        # the request handler.
        return {"error": str(e)}
# Main function
def perform_query(query, authors, num_results):
    """Embed *query*, search Chroma and build one UI update per result slot.

    Args:
        query: Free-text query from the search box.
        authors: Selected author names (may be empty for "all authors").
        num_results: Requested number of hits.

    Returns:
        Exactly ``3 * max_textboxes`` ``gr.update`` dicts, one triple per result
        slot in the (markdown, flag button, hidden id textbox) order in which
        the slots are built. Unused slots are hidden.
    """
    task = "Given a question, retrieve passages that answer the question"
    embeddings = get_embeddings(query, task)
    results = query_chroma(embeddings, authors, num_results)
    total_updates = 3 * max_textboxes

    if isinstance(results, dict):
        # Show the error once in the first markdown slot and hide everything
        # else, rather than writing the message into every button and hidden
        # id textbox.
        error_update = gr.update(visible=True, value=f"Error: {results['error']}")
        return [error_update] + [gr.update(visible=False) for _ in range(total_updates - 1)]

    # Never emit more updates than there are components to receive them.
    results = results[:max_textboxes]

    updates = []
    for slot, res in enumerate(results):
        markdown_content = f"**{res['author']}, {res['book']}, Distance: {res['distance']}**\n\n{res['text']}"
        updates.append(gr.update(visible=True, value=markdown_content))
        updates.append(gr.update(visible=True, value="Flag", elem_id=f"flag-{slot}"))
        updates.append(gr.update(visible=False, value=res['id']))  # id stays hidden
    updates += [gr.update(visible=False) for _ in range(total_updates - len(updates))]
    return updates
# CSV-backed flagging store; its components and output directory are supplied
# via callback.setup(...) once the UI components exist.
callback = gr.CSVLogger()


def flag_output(query, output_text, output_id):
    """Log one flagged (query, result text, result id) row through ``callback``."""
    flagged_row = [query, output_text, output_id]
    callback.flag(flagged_row)
# Gradio interface
# Number of pre-built result slots. Each slot is a (markdown, flag button,
# hidden id textbox) triple, so perform_query must emit 3 updates per slot.
max_textboxes = 30

with gr.Blocks(css=".custom-markdown { border: 1px solid #ccc; padding: 10px; border-radius: 5px; }") as demo:
    gr.Markdown("Enter your query, filter authors (default is all), click **Search** to search. Click **Flag** if a result is relevant to the query and interesting to you. Try reranking the results.")
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(label="query", placeholder="Enter thought...")
            author_inp = gr.Dropdown(label="authors", choices=authors_list, multiselect=True)
            # precision=0 makes Gradio pass an int, which is what the
            # n_results argument of the Chroma query expects.
            num_results_inp = gr.Number(label="number of results", value=10, step=1, minimum=1, maximum=max_textboxes, precision=0)
            btn = gr.Button("Search")

    components = []
    textboxes = []
    flag_buttons = []
    ids = []
    for _ in range(max_textboxes):
        with gr.Column():
            text_out = gr.Markdown(visible=False, elem_classes="custom-markdown")
            flag_btn = gr.Button(value="Flag", visible=False)
            id_out = gr.Textbox(visible=False)
        # Order must match the per-slot update order in perform_query.
        components.extend([text_out, flag_btn, id_out])
        textboxes.append(text_out)
        flag_buttons.append(flag_btn)
        ids.append(id_out)

    # flag_output logs exactly three values (query, result text, result id).
    # CSVLogger writes one header column per set-up component and pairs
    # components with the flagged values positionally. The component list
    # therefore has those same three entries; registering every textbox
    # misaligned the CSV header and rows.
    callback.setup([inp, textboxes[0], ids[0]], "flagged_data_points")

    btn.click(
        fn=perform_query,
        inputs=[inp, author_inp, num_results_inp],
        outputs=components,
    )
    for slot_button, slot_text, slot_id in zip(flag_buttons, textboxes, ids):
        slot_button.click(
            fn=flag_output,
            inputs=[inp, slot_text, slot_id],
            outputs=[],
            preprocess=False,
        )

demo.launch()