import os
import gradio as gr
import chromadb
from sentence_transformers import SentenceTransformer
import spaces


@spaces.GPU
def get_embeddings(queries, task):
    """Embed each query with Linq-Embed-Mistral, prefixed by an instruction.

    Each query is wrapped as ``"Instruct: <task>\\nQuery: <query>"`` before
    encoding.

    Args:
        queries: Sequence of query strings.
        task: Instruction text placed in front of every query.

    Returns:
        The output of ``SentenceTransformer.encode`` for the prompt list
        (one embedding per query).
    """
    # NOTE(review): the model is re-instantiated on every call, which is slow.
    # Loading it once at module level may be preferable, but how that
    # interacts with the ``spaces.GPU`` worker should be confirmed first.
    model = SentenceTransformer(
        "Linq-AI-Research/Linq-Embed-Mistral",
        use_auth_token=os.getenv("HF_TOKEN"),
    )
    prompts = [f"Instruct: {task}\nQuery: {query}" for query in queries]
    return model.encode(prompts)


# Open the persistent Chroma store and look up the per-language collections.
client = chromadb.PersistentClient(path="./chroma")
collection_de = client.get_collection(name="phil_de")
collection_en = client.get_collection(name="phil_en")

# Authors offered in the filter dropdown for each database.
authors_list_de = [
    "Ludwig Wittgenstein",
    "Sigmund Freud",
    "Marcus Aurelius",
    "Friedrich Nietzsche",
    "Epiktet",
    "Ernst Jünger",
    "Georg Christoph Lichtenberg",
    "Balthasar Gracian",
    "Hannah Arendt",
    "Erich Fromm",
    "Albert Camus",
]
authors_list_en = ["Friedrich Nietzsche", "Joscha Bach"]


def query_chroma(collection, embedding, authors, n_results=10):
    """Run a nearest-neighbour query against ``collection``.

    Args:
        collection: Chroma collection to search.
        embedding: A single query embedding exposing ``.tolist()``
            (presumably a numpy row from ``get_embeddings``).
        authors: Author names to restrict results to; falsy means no filter.
        n_results: Maximum number of hits to return.

    Returns:
        A list of dicts with keys ``id``, ``author``, ``book``, ``section``,
        ``title``, ``text`` and ``distance``. On any failure, a one-element
        list ``[{"error": <message>}]`` so the UI can show the problem
        instead of crashing.
    """
    # Use None (not {}) for "no filter": Chroma validates that a where clause
    # contains exactly one operator.
    where_filter = {"author": {"$in": authors}} if authors else None
    try:
        response = collection.query(
            query_embeddings=[embedding.tolist()],
            n_results=n_results,
            where=where_filter,
            include=["documents", "metadatas", "distances"],
        )
        # Chroma returns one inner list per query embedding; we sent one.
        ids = response.get("ids", [[]])[0]
        metadatas = response.get("metadatas", [[]])[0]
        documents = response.get("documents", [[]])[0]
        distances = response.get("distances", [[]])[0]

        formatted_results = []
        for id_, metadata, document_text, distance in zip(ids, metadatas, documents, distances):
            metadata = metadata or {}  # entries stored without metadata come back as None
            formatted_results.append(
                {
                    "id": id_,
                    "author": metadata.get("author", "Unknown author"),
                    "book": metadata.get("book", "Unknown book"),
                    "section": metadata.get("section", "Unknown section"),
                    "title": metadata.get("title", "Untitled"),
                    "text": document_text,
                    "distance": distance,
                }
            )
        return formatted_results
    except Exception as e:  # deliberate best-effort: the error is rendered in the UI
        return [{"error": str(e)}]


def update_authors(database):
    """Return a dropdown update with the author list for ``database``.

    The current selection is cleared so authors picked for the other database
    do not linger as values that are no longer valid choices.
    """
    choices = authors_list_de if database == "German" else authors_list_en
    return gr.update(choices=choices, value=[])


def _parse_queries(text):
    """Split semicolon-separated ``text`` into stripped, non-empty queries."""
    return [part.strip() for part in (text or "").split(";") if part.strip()]


def _format_results(res):
    """Render one query's ``query_chroma`` output as a Markdown string."""
    if not res:
        return "No results found."
    if "error" in res[0]:
        return f"Error retrieving data: {res[0]['error']}"
    # A blank line between entries makes each header start a new paragraph;
    # a single newline would merge it into the previous entry's text.
    return "\n\n".join(f"**{r['author']}, {r['book']}**\n\n{r['text']}" for r in res)


with gr.Blocks() as demo:
    gr.Markdown("Enter your query, filter authors (default is all), click **Search** to search.")
    database_inp = gr.Dropdown(label="Database", choices=["English", "German"], value="German")
    author_inp = gr.Dropdown(label="Authors", choices=authors_list_de, multiselect=True)
    inp = gr.Textbox(label="Query", placeholder="Enter questions separated by semicolons...")
    btn = gr.Button("Search")
    results = gr.State()  # list of (query, results) pairs from the last search

    def perform_query(queries, authors, database):
        """Embed each semicolon-separated query and search the chosen database."""
        task = "Given a question, retrieve passages that answer the question"
        query_list = _parse_queries(queries)
        if not query_list:
            # Nothing to search; skip the GPU embedding call entirely.
            return []
        embeddings = get_embeddings(query_list, task)
        collection = collection_de if database == "German" else collection_en
        return [
            (query, query_chroma(collection, embedding, authors))
            for query, embedding in zip(query_list, embeddings)
        ]

    btn.click(
        perform_query,
        inputs=[inp, author_inp, database_inp],
        outputs=[results],
    )

    @gr.render(inputs=[results])
    def display_accordion(data):
        """Show one accordion per query with its formatted results."""
        # The State is None until the first search (render also runs on load).
        if not data:
            return
        for query, res in data:
            with gr.Accordion(query):
                gr.Markdown(_format_results(res))

    database_inp.change(
        fn=update_authors,
        inputs=[database_inp],
        outputs=[author_inp],
    )

demo.launch()