File size: 2,491 Bytes
6ccb1e9
 
703992e
 
6ccb1e9
28c69cd
 
bac38da
 
f61a16c
bac38da
3aa419b
6ccb1e9
bac38da
d542df6
703992e
ab7881a
703992e
392455e
ab7881a
 
 
6ccb1e9
ab7881a
6ccb1e9
 
 
 
 
 
 
 
703992e
6ccb1e9
 
703992e
6ccb1e9
 
 
 
 
 
703992e
 
ab7881a
 
703992e
 
6ccb1e9
 
 
703992e
 
6ccb1e9
 
 
 
 
 
 
 
 
 
703992e
6ccb1e9
 
 
 
 
5c4abd8
6ccb1e9
 
 
 
 
 
 
 
 
 
 
296274b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import gradio as gr
from huggingface_hub import InferenceClient
import chromadb
from chromadb.config import Settings

from chromadb import PersistentClient

# Identifier of the hosted model that generates the chatbot's replies.
MODEL_ID = "unsloth/Llama-3.2-3B-Instruct"
# Filesystem path handed to ChromaDB's persistent client.
CHROMA_PATH = "./chromadb_directory/chromadb_file"
# Name of the collection that is queried for context documents.
COLLECTION_NAME = "my_collection"

# Client used to request chat completions from the model.
inference_client = InferenceClient(model=MODEL_ID)

# Persistent ChromaDB client and the collection looked up from it.
client_db = PersistentClient(path=CHROMA_PATH)
collection = client_db.get_collection(COLLECTION_NAME)

def retrieve_from_chromadb(query, n_results=5):
    """Return the text of the documents most similar to *query*.

    ChromaDB reports query results with one inner list per query text, so
    ``results["documents"]`` is a list of lists. The inner lists are
    flattened here so that each returned element is the text of a single
    document rather than the ``repr`` of a whole result list.

    Args:
        query: The query text (or list of query texts) to search for.
        n_results: Maximum number of documents to fetch per query text.

    Returns:
        A flat list of document strings; empty when no documents came back.
    """
    results = collection.query(query_texts=query, n_results=n_results)
    # Tolerate a missing or None "documents" entry instead of raising.
    per_query_docs = results.get("documents") or []
    return [
        str(doc)
        for docs in per_query_docs
        for doc in docs or []
        if doc is not None
    ]

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat reply to *message*, with ChromaDB documents as context.

    The prompt is built from the system message, the prior conversation
    turns, and a final user turn that contains the retrieved documents
    followed by the user's message.

    Args:
        message: The latest user message.
        history: Earlier turns as ``(user_text, assistant_text)`` pairs;
            empty entries in a pair are skipped.
        system_message: Content of the leading system prompt.
        max_tokens: Passed through as the completion's ``max_tokens``.
        temperature: Passed through as the sampling temperature.
        top_p: Passed through as the nucleus-sampling ``top_p``.

    Yields:
        The reply accumulated so far, once per streamed piece of text.
    """
    messages = [{"role": "system", "content": system_message}]

    # Replay the earlier conversation turns in order.
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})

    # Put the retrieved documents ahead of the user's question.
    retrieved_docs = retrieve_from_chromadb(message)
    context = "\n".join(retrieved_docs) + "\nUser: " + message
    messages.append({"role": "user", "content": context})

    response = ""

    # The loop variable is deliberately not called `message`, so the
    # function parameter of that name is not shadowed.
    for chunk in inference_client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # A streamed chunk may carry no text (no choices, or a
        # ``delta.content`` of None); skip it rather than fail on str + None.
        if not chunk.choices:
            continue
        token = chunk.choices[0].delta.content
        if not token:
            continue
        response += token
        yield response

# Extra controls for the chat UI. They are listed in the same order as the
# `system_message`, `max_tokens`, `temperature` and `top_p` parameters of
# `respond`.
_system_message_box = gr.Textbox(
    value="You are a friendly Chatbot.",
    label="System message",
)
_max_tokens_slider = gr.Slider(
    minimum=1,
    maximum=2048,
    value=512,
    step=1,
    label="Max new tokens",
)
_temperature_slider = gr.Slider(
    minimum=0.1,
    maximum=4.0,
    value=0.7,
    step=0.1,
    label="Temperature",
)
_top_p_slider = gr.Slider(
    minimum=0.1,
    maximum=1.0,
    value=0.95,
    step=0.05,
    label="Top-p (nucleus sampling)",
)

# Chat interface wired to `respond` and the controls above.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        _system_message_box,
        _max_tokens_slider,
        _temperature_slider,
        _top_p_slider,
    ],
)

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()