import os
import gradio as gr
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
import chromadb
from sentence_transformers import SentenceTransformer

# Initialize the Llama model
llm = Llama(
    model_path=hf_hub_download(
        repo_id="microsoft/Phi-3-mini-4k-instruct-gguf",
        filename="Phi-3-mini-4k-instruct-q4.gguf",
    ),
    n_ctx=2048,
    n_gpu_layers=50,  # Adjust based on your VRAM
)

# Initialize ChromaDB Vector Store
class VectorStore:
    def __init__(self, collection_name):
        self.embedding_model = SentenceTransformer('sentence-transformers/multi-qa-MiniLM-L6-cos-v1')
        self.chroma_client = chromadb.Client()
        self.collection = self.chroma_client.create_collection(name=collection_name)

    def populate_vectors(self, texts, ids):
        embeddings = self.embedding_model.encode(texts, batch_size=32).tolist()
        for text, embedding, doc_id in zip(texts, embeddings, ids):
            self.collection.add(embeddings=[embedding], documents=[text], ids=[doc_id])

    def search_context(self, query, n_results=1):
        query_embedding = self.embedding_model.encode([query]).tolist()
        results = self.collection.query(query_embeddings=query_embedding, n_results=n_results)
        return results['documents']

# Example initialization (assuming you've already populated the vector store)
vector_store = VectorStore("embedding_vector")

# Populate with your data if not already done
# vector_store.populate_vectors(your_texts, your_ids)
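# Illustrative sketch only (the sample texts and ids below are made up): seed the
# collection with a couple of recipe snippets so search_context has something to
# retrieve. Uncomment and adapt to your own data.
# sample_texts = [
#     "Fried rice: stir-fry leftover cooked rice with egg, soy sauce, and vegetables.",
#     "Rice pudding: simmer cooked rice with milk, sugar, and a pinch of cinnamon.",
# ]
# vector_store.populate_vectors(sample_texts, ids=["recipe-0", "recipe-1"])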

def generate_text(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    # Retrieve context from the vector store (ChromaDB returns documents as a nested list)
    context_results = vector_store.search_context(message, n_results=1)
    context = context_results[0][0] if context_results and context_results[0] else ""

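    # Note: this is the Llama-2 style [INST]/<<SYS>> prompt template. Phi-3 defines
    # its own chat format, but llama-cpp-python will still complete this prompt as
    # plain text.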
    input_prompt = f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n {context}\n"
    for interaction in history:
        input_prompt += f"{interaction[0]} [/INST] {interaction[1]} </s><s> [INST] "
    input_prompt += f"{message} [/INST] "

    temp = ""
    output = llm(
        input_prompt,
        temperature=temperature,
        top_p=top_p,
        top_k=40,
        repeat_penalty=1.1,
        max_tokens=max_tokens,
        stop=["", " \n", "ASSISTANT:", "USER:", "SYSTEM:"],
        stream=True,
    )
    for out in output:
        temp += out["choices"][0]["text"]
        yield temp

# Define the Gradio interface
demo = gr.ChatInterface(
    generate_text,
    title="llama-cpp-python on GPU with ChromaDB",
    description="Running LLM with context retrieval from ChromaDB",
    examples=[
        ["I have leftover rice, what can I make out of it?"],
        ["Can I make lunch for two people with this?"],
    ],
    cache_examples=False,
    retry_btn=None,
    undo_btn="Delete Previous",
    clear_btn="Clear",
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    demo.launch()