# Hugging Face Spaces status: Sleeping
import functools
import json
import os

import chromadb
import gradio as gr
import pandas as pd
import requests
import spaces
from sentence_transformers import SentenceTransformer
# Instruction used when the caller does not supply its own retrieval task.
_DEFAULT_TASK = "Given a question, retrieve passages that answer the question"


@functools.lru_cache(maxsize=1)
def _load_embedding_model():
    """Load the embedding model once and reuse it for every later query."""
    return SentenceTransformer(
        "Linq-AI-Research/Linq-Embed-Mistral",
        use_auth_token=os.getenv("HF_TOKEN"),
    )


def get_embeddings(text, task=None):
    """Embed a query string with the Linq-Embed-Mistral model.

    Args:
        text: The user's query.
        task: Instruction placed in front of the query. When it is missing,
            empty, or not a string, the default question-answering retrieval
            instruction is used.

    Returns:
        The NumPy array produced by encoding the single prompt (a batch of
        one), moved to the CPU.
    """
    if not (isinstance(task, str) and task.strip()):
        task = _DEFAULT_TASK
    prompt = f"Instruct: {task}\nQuery: {text}"
    # encode() is given a list so the result keeps its batch dimension.
    query_embeddings = _load_embedding_model().encode(
        [prompt], convert_to_tensor=True
    )
    return query_embeddings.cpu().numpy()
# Open the on-disk Chroma store and fetch the collection that gets searched.
client = chromadb.PersistentClient(path="./chroma")
collection = client.get_collection(name="phil_de")

# Author names offered as choices in the UI's author filter.
authors_list = [
    "Ludwig Wittgenstein",
    "Sigmund Freud",
    "Marcus Aurelius",
    "Friedrich Nietzsche",
    "Epiktet",
    "Ernst Jünger",
    "Georg Christoph Lichtenberg",
    "Balthasar Gracian",
    "Hannah Arendt",
    "Erich Fromm",
    "Albert Camus",
]
# Alternative, shorter list (currently disabled):
# authors_list = ["Friedrich Nietzsche", "Joscha Bach"]
def query_chroma(embeddings, authors, num_results=10):
    """Search the Chroma collection and return the closest passages.

    Args:
        embeddings: Query embedding, either one flat vector or a batch of
            vectors (list or NumPy array). Only the hits for the first
            vector of a batch are returned.
        authors: Author names to restrict the search to. An empty or missing
            selection searches all authors.
        num_results: Maximum number of hits; coerced to ``int`` because UI
            number inputs may deliver floats.

    Returns:
        A list of dicts with the keys ``id``, ``author``, ``book``,
        ``section``, ``title``, ``text`` and ``distance``. On any failure a
        string starting with ``"Failed to query the database: "`` is
        returned instead, so callers must check the type.
    """
    try:
        # Normalise to a list of vectors: a 2-D array becomes its rows, a
        # flat vector is wrapped so it forms a batch of one.
        vectors = embeddings.tolist() if hasattr(embeddings, "tolist") else list(embeddings)
        if vectors and not hasattr(vectors[0], "__len__"):
            vectors = [vectors]

        # No author selected means "no filter": pass None, not an empty dict.
        where_filter = {"author": {"$in": list(authors)}} if authors else None

        results = collection.query(
            query_embeddings=vectors,
            n_results=int(num_results),
            where=where_filter,
            include=["documents", "metadatas", "distances"],
        )

        # Each result field is a list per query vector; take the first query.
        ids = results.get('ids', [[]])[0]
        metadatas = results.get('metadatas', [[]])[0]
        documents = results.get('documents', [[]])[0]
        distances = results.get('distances', [[]])[0]

        formatted_results = []
        for id_, metadata, document_text, distance in zip(ids, metadatas, documents, distances):
            # A document stored without metadata falls back to the defaults.
            metadata = metadata or {}
            formatted_results.append({
                "id": id_,
                "author": metadata.get('author', 'Unknown author'),
                "book": metadata.get('book', 'Unknown book'),
                "section": metadata.get('section', 'Unknown section'),
                "title": metadata.get('title', 'Untitled'),
                "text": document_text,
                "distance": distance,
            })
        return formatted_results
    except Exception as e:
        # Boundary handler: failures are reported to the caller as a message.
        return f"Failed to query the database: {str(e)}"
# Main function
def perform_query(query, task, author, num_results):
    """Run a search and build the updates for every result slot in the UI.

    Args:
        query: The user's query text.
        task: Retrieval instruction forwarded to ``get_embeddings``.
        author: Author names used to filter the search (may be empty).
        num_results: Requested number of hits.

    Returns:
        A flat list of exactly ``3 * max_textboxes`` ``gr.update`` objects,
        three per slot in the order (markdown text, flag button, hidden id).
        Slots beyond the number of hits are hidden.

    Raises:
        gr.Error: If the database query failed; the message comes from
            ``query_chroma``.
    """
    embeddings = get_embeddings(query, task)
    initial_results = query_chroma(embeddings, author, num_results)

    # query_chroma signals failure by returning a message string, not hits.
    if isinstance(initial_results, str):
        raise gr.Error(initial_results)

    # Never emit more updates than there are pre-built slots.
    shown = initial_results[:max_textboxes]

    updates = []
    for index, res in enumerate(shown):
        meta = f"{res['author']}, {res['book']}, Distance: {res['distance']}"
        markdown_content = f"**{meta}**\n\n{res['text']}"
        updates.append(gr.update(visible=True, value=markdown_content))
        updates.append(gr.update(visible=True, value="Flag", elem_id=f"flag-{index}"))
        updates.append(gr.update(visible=False, value=res['id']))  # Hide the ID textbox

    # Hide the three components of every unused slot so the output count
    # matches the number of wired components.
    unused_slots = max_textboxes - len(shown)
    updates.extend(gr.update(visible=False) for _ in range(3 * unused_slots))
    return updates
# CSV-backed flagging callback used to record results the user flags.
callback = gr.CSVLogger()


def flag_output(query, output_text, output_id):
    """Record one flagged result (query, shown text, document id) via the logger."""
    flagged_row = [query, output_text, output_id]
    callback.flag(flagged_row)
# Gradio interface
# Number of result slots created up front; they start hidden and are revealed
# by the updates that perform_query returns.
max_textboxes = 30


def _run_search(query, authors, num_results):
    """Adapt the three UI inputs to perform_query, which also takes a task."""
    # The UI has no task input, so the embedding default is requested.
    return perform_query(query, None, authors, num_results)


with gr.Blocks(css=".custom-markdown { border: 1px solid #ccc; padding: 10px; border-radius: 5px; }") as demo:
    gr.Markdown("Enter your query, filter authors (default is all), click **Search** to search. Click **Flag** if a result is relevant to the query and interesting to you. Try reranking the results.")
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(label="query", placeholder="Enter thought...")
            author_inp = gr.Dropdown(label="authors", choices=authors_list, multiselect=True)
            num_results_inp = gr.Number(label="number of results", value=10, step=1, minimum=1, maximum=max_textboxes)
            btn = gr.Button("Search")

    # Flat list of all slot components (text, button, id per slot) plus
    # per-kind lists used for wiring the flag buttons.
    components = []
    textboxes = []
    flag_buttons = []
    ids = []
    for _ in range(max_textboxes):
        with gr.Column():
            text_out = gr.Markdown(visible=False, elem_classes="custom-markdown")
            flag_btn = gr.Button(value="Flag", visible=False)
            id_out = gr.Textbox(visible=False)
        components.extend([text_out, flag_btn, id_out])
        textboxes.append(text_out)
        flag_buttons.append(flag_btn)
        ids.append(id_out)

    # flag_output logs exactly three values (query, text, id), so the logger
    # is set up with three matching components to keep CSV columns aligned.
    callback.setup([inp, textboxes[0], ids[0]], "flagged_data_points")

    btn.click(
        fn=_run_search,
        inputs=[inp, author_inp, num_results_inp],
        outputs=components,
    )

    # Each flag button logs the query together with its own slot's text and id.
    for text_out, flag_btn, id_out in zip(textboxes, flag_buttons, ids):
        flag_btn.click(
            fn=flag_output,
            inputs=[inp, text_out, id_out],
            outputs=[],
            preprocess=False,
        )

demo.launch()