File size: 4,961 Bytes
2e98c79
 
 
 
 
 
 
 
 
3bbb5f4
2e98c79
10f043b
2e98c79
4295f9e
2e98c79
 
 
38d2199
2e98c79
 
 
ff486ce
2e98c79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b41c2a6
2e98c79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import requests
import gradio as gr
import chromadb
import json
import pandas as pd

from sentence_transformers import SentenceTransformer

import spaces

_EMBEDDING_MODEL_NAME = "Linq-AI-Research/Linq-Embed-Mistral"
_DEFAULT_TASK = "Given a question, retrieve passages that answer the question"

# Lazily populated, per-process cache so the model is not rebuilt on every query.
_embedding_model = None


def _load_embedding_model():
    """Return the shared SentenceTransformer, creating it on first use."""
    global _embedding_model
    if _embedding_model is None:
        _embedding_model = SentenceTransformer(
            _EMBEDDING_MODEL_NAME,
            use_auth_token=os.getenv("HF_TOKEN"),
        )
    return _embedding_model


@spaces.GPU
def get_embeddings(text, task=None):
    """Embed a query string with the instruction-style prompt.

    Args:
        text: The query to embed.
        task: Instruction placed in the ``Instruct:`` part of the prompt.
            When it is missing, empty, or not a string, the default
            retrieval instruction is used.

    Returns:
        A NumPy array produced by encoding a one-element list holding the
        prompt (presumably one row per prompt, i.e. a single row here --
        confirm against the model's output shape).
    """
    # Previously the caller's task was always overwritten; now it is only
    # replaced when no usable instruction was supplied.
    if not (isinstance(task, str) and task.strip()):
        task = _DEFAULT_TASK

    prompt = f"Instruct: {task}\nQuery: {text}"
    model = _load_embedding_model()
    # encode() is given a list so the prompt is treated as one input.
    query_embeddings = model.encode([prompt], convert_to_tensor=True)
    return query_embeddings.cpu().numpy()


# Open the on-disk Chroma store and fetch the collection that gets queried.
client = chromadb.PersistentClient(path="./chroma")
collection = client.get_collection(name="phil_de")

# Author names available for filtering.
authors_list = [
    "Ludwig Wittgenstein",
    "Sigmund Freud",
    "Marcus Aurelius",
    "Friedrich Nietzsche",
    "Epiktet",
    "Ernst Jünger",
    "Georg Christoph Lichtenberg",
    "Balthasar Gracian",
    "Hannah Arendt",
    "Erich Fromm",
    "Albert Camus",
]
# authors_list = ["Friedrich Nietzsche", "Joscha Bach"]

def query_chroma(embeddings, authors, num_results=10):
    """Query the module-level Chroma collection and flatten the hits.

    Args:
        embeddings: Query embedding(s). Either a single flat vector or a
            batch of vectors (list, tuple, or any object exposing
            ``tolist()``, e.g. a NumPy array). Only the hits for the first
            query vector are returned.
        authors: Author name(s) to restrict the search to. A falsy value
            disables the filter; a single string is treated as one author.
        num_results: Maximum number of hits to request (coerced to ``int``).

    Returns:
        A list of dicts with keys ``id``, ``author``, ``book``, ``section``,
        ``title``, ``text`` and ``distance``. On any failure, a string
        starting with ``"Failed to query the database: "`` is returned
        instead of raising.
    """
    try:
        def _plain(value):
            # Arrays/tensors expose tolist(); plain Python values pass through.
            return value.tolist() if hasattr(value, "tolist") else value

        vectors = _plain(embeddings)
        vectors = [_plain(item) for item in vectors]
        # A flat vector is wrapped into a batch of one. An input that is
        # already a batch (e.g. one row per prompt) is passed as-is rather
        # than being nested a second time.
        if vectors and not isinstance(vectors[0], (list, tuple)):
            vectors = [vectors]

        query_kwargs = {
            "query_embeddings": vectors,
            "n_results": int(num_results),
            "include": ["documents", "metadatas", "distances"],
        }
        if authors:
            if isinstance(authors, str):
                authors = [authors]
            # The filter is omitted entirely when no author is selected,
            # instead of sending an empty ``where`` dict.
            query_kwargs["where"] = {"author": {"$in": list(authors)}}

        results = collection.query(**query_kwargs)

        ids = results.get('ids', [[]])[0]
        metadatas = results.get('metadatas', [[]])[0]
        documents = results.get('documents', [[]])[0]
        distances = results.get('distances', [[]])[0]

        formatted_results = []
        for id_, metadata, document_text, distance in zip(ids, metadatas, documents, distances):
            metadata = metadata or {}  # tolerate entries stored without metadata
            formatted_results.append({
                "id": id_,
                "author": metadata.get('author', 'Unknown author'),
                "book": metadata.get('book', 'Unknown book'),
                "section": metadata.get('section', 'Unknown section'),
                "title": metadata.get('title', 'Untitled'),
                "text": document_text,
                "distance": distance,
            })

        return formatted_results
    except Exception as e:
        return f"Failed to query the database: {str(e)}"


# Main function
def perform_query(query, task, author, num_results):
    """Run a search and build the Gradio updates for every result slot.

    Args:
        query: Text to search for.
        task: Instruction forwarded to ``get_embeddings``.
        author: Author filter forwarded to ``query_chroma``.
        num_results: Number of hits to request.

    Returns:
        A list of exactly ``3 * max_textboxes`` ``gr.update`` values, in
        (markdown, flag button, id textbox) order per slot. Slots without a
        result are hidden. If ``query_chroma`` returns an error message
        string, it is shown in the first markdown slot.
    """
    embeddings = get_embeddings(query, task)
    hits = query_chroma(embeddings, author, num_results)

    hidden = gr.update(visible=False)
    updates = []

    if isinstance(hits, str):
        # query_chroma signals failure with a message string, not a list.
        updates.extend([gr.update(visible=True, value=hits), hidden, hidden])
        used_slots = 1
    else:
        # Never emit more slots than there are output components.
        hits = hits[:max_textboxes]
        for slot, res in enumerate(hits):
            meta = f"{res['author']}, {res['book']}, Distance: {res['distance']}"
            markdown_content = f"**{meta}**\n\n{res['text']}"
            updates.append(gr.update(visible=True, value=markdown_content))
            # The id suffix is now the slot index (it was derived from the
            # running update count, which skipped numbers).
            updates.append(gr.update(visible=True, value="Flag", elem_id=f"flag-{slot}"))
            updates.append(gr.update(visible=False, value=res['id']))  # Hide the ID textbox
        used_slots = len(hits)

    # Pad with hidden updates: three components per unused slot.
    updates += [hidden] * (3 * (max_textboxes - used_slots))

    return updates


# CSV-backed flagging callback shared by all flag buttons.
callback = gr.CSVLogger()


def flag_output(query, output_text, output_id):
    """Pass the query, result text and result id to the CSV logger."""
    flagged_row = [query, output_text, output_id]
    callback.flag(flagged_row)
        
# Gradio interface
max_textboxes = 30


def _run_search(query, authors, num_results):
    """Adapt the three UI inputs to perform_query's four parameters.

    There is no task input in the UI, so ``None`` is passed for ``task``.
    """
    return perform_query(query, None, authors, num_results)


with gr.Blocks(css=".custom-markdown { border: 1px solid #ccc; padding: 10px; border-radius: 5px; }") as demo:
    gr.Markdown("Enter your query, filter authors (default is all), click **Search** to search. Click **Flag** if a result is relevant to the query and interesting to you. Try reranking the results.")
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(label="query", placeholder="Enter thought...")
            author_inp = gr.Dropdown(label="authors", choices=authors_list, multiselect=True)
            # precision=0 makes the component hand over an int result count.
            num_results_inp = gr.Number(
                label="number of results",
                value=10,
                step=1,
                minimum=1,
                maximum=max_textboxes,
                precision=0,
            )
            btn = gr.Button("Search")

    # Flat list in (markdown, flag button, id textbox) order per slot; this is
    # the order in which perform_query returns its updates.
    components = []
    textboxes = []
    flag_buttons = []
    ids = []

    # Pre-create hidden result slots; a search makes the needed ones visible.
    for _ in range(max_textboxes):
        with gr.Column():
            text_out = gr.Markdown(visible=False, elem_classes="custom-markdown")
            flag_btn = gr.Button(value="Flag", visible=False)
            id_out = gr.Textbox(visible=False)
            components.extend([text_out, flag_btn, id_out])
            textboxes.append(text_out)
            flag_buttons.append(flag_btn)
            ids.append(id_out)

    # Each flag call sends three values (query, result text, result id), so
    # the logger is set up with one representative component per value.
    callback.setup([inp, textboxes[0], ids[0]], "flagged_data_points")

    btn.click(
        fn=_run_search,
        inputs=[inp, author_inp, num_results_inp],
        outputs=components
    )

    # Wire every flag button to its own markdown/id pair.
    for flag_btn, text_out, id_out in zip(flag_buttons, textboxes, ids):
        flag_btn.click(
            fn=flag_output,
            inputs=[inp, text_out, id_out],
            outputs=[],
            preprocess=False
        )


demo.launch()