HonestAnnie commited on
Commit
2a5653d
1 Parent(s): 7d6132f
Files changed (1) hide show
  1. app.py +26 -15
app.py CHANGED
@@ -18,14 +18,12 @@ collection_en = client.get_collection(name="phil_en")
18
  authors_list_de = ["Ludwig Wittgenstein", "Sigmund Freud", "Marcus Aurelius", "Friedrich Nietzsche", "Epiktet", "Ernst Jünger", "Georg Christoph Lichtenberg", "Balthasar Gracian", "Hannah Arendt", "Erich Fromm", "Albert Camus"]
19
  authors_list_en = ["Friedrich Nietzsche", "Joscha Bach"]
20
 
21
- def query_chroma(collection, embeddings, authors):
22
  try:
23
  where_filter = {"author": {"$in": authors}} if authors else {}
24
-
25
- embeddings_list = embeddings[0].tolist()
26
-
27
  results = collection.query(
28
- query_embeddings=[embeddings_list],
29
  n_results=10,
30
  where=where_filter,
31
  include=["documents", "metadatas", "distances"]
@@ -51,7 +49,7 @@ def query_chroma(collection, embeddings, authors):
51
 
52
  return formatted_results
53
  except Exception as e:
54
- return {"error": str(e)}
55
 
56
  def update_authors(database):
57
  return gr.update(choices=authors_list_de if database == "German" else authors_list_en)
@@ -64,26 +62,39 @@ with gr.Blocks() as demo:
64
  author_inp = gr.Dropdown(label="Authors", choices=authors_list_de, multiselect=True)
65
  inp = gr.Textbox(label="Query", placeholder="Enter questions separated by semicolons...")
66
  btn = gr.Button("Search")
 
67
 
68
  def perform_query(queries, authors, database):
69
- queries = queries.split(';')
70
  task = "Given a question, retrieve passages that answer the question"
 
71
  embeddings = get_embeddings(queries, task)
72
  collection = collection_de if database == "German" else collection_en
73
-
74
- results = [query_chroma(collection, embedding, authors) for embedding in embeddings]
75
-
76
- for query, result in zip(queries, results):
77
- with gr.Accordion(query):
78
- markdown_contents = "\n".join(f"**{res['author']}, {res['book']}**\n\n{res['text']}" for res in result)
79
- gr.Markdown(value=markdown_contents)
80
 
81
  btn.click(
82
  perform_query,
83
  inputs=[inp, author_inp, database_inp],
84
- outputs=[]
85
  )
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  database_inp.change(
88
  fn=lambda database: update_authors(database),
89
  inputs=[database_inp],
 
18
  authors_list_de = ["Ludwig Wittgenstein", "Sigmund Freud", "Marcus Aurelius", "Friedrich Nietzsche", "Epiktet", "Ernst Jünger", "Georg Christoph Lichtenberg", "Balthasar Gracian", "Hannah Arendt", "Erich Fromm", "Albert Camus"]
19
  authors_list_en = ["Friedrich Nietzsche", "Joscha Bach"]
20
 
21
+ def query_chroma(collection, embedding, authors):
22
  try:
23
  where_filter = {"author": {"$in": authors}} if authors else {}
24
+ # Directly use the embedding provided, already in list format suitable for the query
 
 
25
  results = collection.query(
26
+ query_embeddings=[embedding.tolist()], # Ensure embedding is properly formatted
27
  n_results=10,
28
  where=where_filter,
29
  include=["documents", "metadatas", "distances"]
 
49
 
50
  return formatted_results
51
  except Exception as e:
52
+ return [{"error": str(e)}]
53
 
54
  def update_authors(database):
55
  return gr.update(choices=authors_list_de if database == "German" else authors_list_en)
 
62
  author_inp = gr.Dropdown(label="Authors", choices=authors_list_de, multiselect=True)
63
  inp = gr.Textbox(label="Query", placeholder="Enter questions separated by semicolons...")
64
  btn = gr.Button("Search")
65
+ results = gr.State() # Store results in a State component
66
 
67
  def perform_query(queries, authors, database):
 
68
  task = "Given a question, retrieve passages that answer the question"
69
+ queries = queries.split(';')
70
  embeddings = get_embeddings(queries, task)
71
  collection = collection_de if database == "German" else collection_en
72
+ results_data = []
73
+ for query, embedding in zip(queries, embeddings):
74
+ res = query_chroma(collection, embedding, authors)
75
+ results_data.append((query, res))
76
+ return results_data
 
 
77
 
78
  btn.click(
79
  perform_query,
80
  inputs=[inp, author_inp, database_inp],
81
+ outputs=[results]
82
  )
83
 
84
+ @gr.render(inputs=[results])
85
+ def display_accordion(data):
86
+ output_blocks = []
87
+ for query, res in data:
88
+ with gr.Accordion(query) as acc:
89
+ if not res:
90
+ markdown_contents = "No results found."
91
+ elif "error" in res[0]:
92
+ markdown_contents = f"Error retrieving data: {res[0]['error']}"
93
+ else:
94
+ markdown_contents = "\n".join(f"**{r['author']}, {r['book']}**\n\n{r['text']}" for r in res)
95
+ gr.Markdown(markdown_contents)
96
+
97
+
98
  database_inp.change(
99
  fn=lambda database: update_authors(database),
100
  inputs=[database_inp],