File size: 6,582 Bytes
2d5272c
a2568fb
2d5272c
59546e8
9f80d5d
 
 
 
59546e8
96d0c1a
9f80d5d
59546e8
 
58785bd
59546e8
6bd7eb3
2d5272c
59546e8
 
 
 
 
 
2d5272c
59546e8
7f9077b
b81ad51
82b6e1c
2d5272c
 
57ddcad
 
59546e8
57ddcad
59546e8
57ddcad
 
59546e8
57ddcad
59546e8
57ddcad
2d5272c
 
7f9077b
59546e8
2d5272c
59546e8
 
2d5272c
9f80d5d
59546e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50636d8
 
59546e8
 
 
 
50636d8
59546e8
 
50636d8
 
59546e8
50636d8
59546e8
 
 
 
 
50636d8
 
59546e8
50636d8
 
59546e8
 
 
 
 
 
 
 
50636d8
59546e8
 
 
 
 
 
 
 
50636d8
 
 
 
59546e8
50636d8
59546e8
 
 
9f80d5d
 
59546e8
9f80d5d
 
59546e8
 
9f80d5d
 
59546e8
2d5272c
9f80d5d
2d5272c
59546e8
 
2d5272c
 
 
 
 
 
9f80d5d
2d5272c
 
 
 
9f80d5d
59546e8
2d5272c
9f80d5d
 
 
2d5272c
 
 
 
9f80d5d
2d5272c
 
59546e8
 
 
9f80d5d
 
 
59546e8
2d5272c
7f9077b
 
9f80d5d
2d5272c
59546e8
2d5272c
 
9f80d5d
7f9077b
2d5272c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import streamlit as st
from openai import OpenAI
import os
import json
from dotenv import load_dotenv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain_community.llms import HuggingFaceHub
from langchain.chains import RetrievalQA
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from tqdm import tqdm
import random

# Load environment variables (e.g. HUGGINGFACEHUB_API_TOKEN) from a .env file.
load_dotenv()

# Tuning constants for chunking, indexing, and retrieval.
CHUNK_SIZE = 8192  # max characters per chunk produced by the splitter
CHUNK_OVERLAP = 200  # characters shared between consecutive chunks
BATCH_SIZE = 100  # chunks embedded per Chroma batch while building the index
RETRIEVER_K = 4  # number of chunks retrieved per query
VECTORSTORE_PATH = "./vectorstore"  # on-disk Chroma persistence directory

# Display name -> Hugging Face Hub repo id for each selectable chat model.
model_links = {
    "Meta-Llama-3.1-8B": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "Mistral-7B-Instruct-v0.3": "mistralai/Mistral-7B-Instruct-v0.3",
}

# Sidebar blurb and logo for each model; keys must match model_links.
model_info = {
    "Meta-Llama-3.1-8B": {
        "description": """The Llama (3.1) model is a **Large Language Model (LLM)** that's able to have question and answer interactions.
        \nIt was created by the [**Meta's AI**](https://llama.meta.com/) team and has over **8 billion parameters.**\n""",
        "logo": "llama_logo.gif",
    },
    "Mistral-7B-Instruct-v0.3": {
        "description": """The Mistral-7B-Instruct-v0.3 Large Language Model (LLM) is an instruct fine-tuned version of the Mistral-7B-v0.3.
        \nIt was created by the [**Mistral AI**](https://mistral.ai/news/announcing-mistral-7b/) team as has over **7 billion parameters.**\n""",
        "logo": "https://mistral.ai/images/logo_hubc88c4ece131b91c7cb753f40e9e1cc5_2589_256x0_resize_q97_h2_lanczos_3.webp",
    },
}

# Random dog images shown on the error screen.
# NOTE(review): these look like local files bundled with the app — confirm they exist.
random_dogs = ["randomdog.jpg", "randomdog2.jpg", "randomdog3.jpg"]  # Add more as needed

# Sentence-transformer embeddings used both for indexing and querying.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

def load_and_process_documents(file_path):
    """Load documents from a JSON file and split them into overlapping chunks.

    The file must contain a top-level "documents" list whose entries carry
    "content", "title", and "id" keys.

    Args:
        file_path: Path to the JSON file to load.

    Returns:
        A list of chunked langchain ``Document`` objects, or an empty list on
        any failure (the error is surfaced in the Streamlit UI instead of
        crashing the app).
    """
    try:
        # Explicit encoding: the platform default may not be UTF-8 (e.g. on
        # Windows), which would corrupt non-ASCII document text.
        with open(file_path, "r", encoding="utf-8") as file:
            data = json.load(file)

        documents = data.get("documents", [])

        if not documents:
            raise ValueError("No valid documents found in JSON file.")

        doc_objects = [
            Document(
                page_content=doc["content"],
                metadata={"title": doc["title"], "id": doc["id"]},
            )
            for doc in documents
        ]

        # Overlapping chunks preserve context across chunk boundaries so
        # retrieval doesn't lose sentences cut at a split point.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
        )
        return text_splitter.split_documents(doc_objects)
    except Exception as e:
        # Broad catch is deliberate: any load/parse failure is reported in the
        # UI and the caller degrades gracefully on an empty result.
        st.error(f"Error loading documents: {str(e)}")
        return []

def get_vectorstore(file_path):
    """Load the persisted Chroma vectorstore, or build one from *file_path*.

    If VECTORSTORE_PATH already exists on disk, the persisted index is reused;
    otherwise the documents are loaded, chunked, embedded in batches, and
    persisted.

    Args:
        file_path: JSON document file passed to load_and_process_documents.

    Returns:
        A Chroma vectorstore, or None on failure (error shown in the UI).
    """
    try:
        if os.path.exists(VECTORSTORE_PATH):
            print("Loading existing vectorstore...")
            return Chroma(
                persist_directory=VECTORSTORE_PATH, embedding_function=embeddings
            )

        print("Creating new vectorstore...")
        splits = load_and_process_documents(file_path)

        # Bug fix: with zero chunks the batch loop below never runs, leaving
        # vectorstore as None and crashing on vectorstore.persist() with an
        # AttributeError. Fail explicitly instead.
        if not splits:
            raise ValueError("No document chunks available to index.")

        vectorstore = None
        # Embed in batches so progress is visible and memory stays bounded.
        for i in tqdm(range(0, len(splits), BATCH_SIZE), desc="Processing batches"):
            batch = splits[i : i + BATCH_SIZE]
            if vectorstore is None:
                # First batch creates the store (and its persist directory).
                vectorstore = Chroma.from_documents(
                    documents=batch,
                    embedding=embeddings,
                    persist_directory=VECTORSTORE_PATH,
                )
            else:
                vectorstore.add_documents(documents=batch)

        vectorstore.persist()
        return vectorstore
    except Exception as e:
        st.error(f"Error creating vectorstore: {str(e)}")
        return None

@st.cache_resource(hash_funcs={"builtins.tuple": lambda _: None})
def setup_rag_pipeline(file_path, model_name, temperature):
    """Build (and cache) a RetrievalQA chain over the document vectorstore.

    Cached with st.cache_resource so Streamlit reruns reuse the same chain for
    identical (file_path, model_name, temperature) inputs.

    Args:
        file_path: JSON document file used to build/load the vectorstore.
        model_name: Key into model_links selecting the HF Hub model.
        temperature: Sampling temperature from the UI slider (0.0–1.0).

    Returns:
        A RetrievalQA chain, or None on failure (error shown in the UI).
    """
    try:
        vectorstore = get_vectorstore(file_path)
        if vectorstore is None:
            raise ValueError("Failed to create or load vectorstore.")

        # The HF Inference API rejects temperature == 0 (it must be strictly
        # positive), but the UI slider bottoms out at 0.0 — clamp to a small
        # positive value so near-deterministic output still works.
        llm = HuggingFaceHub(
            repo_id=model_links[model_name],
            model_kwargs={"temperature": max(temperature, 0.01), "max_length": 4000},
        )

        return RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",  # concatenate retrieved chunks into one prompt
            retriever=vectorstore.as_retriever(search_kwargs={"k": RETRIEVER_K}),
            return_source_documents=True,
        )
    except Exception as e:
        st.error(f"Error setting up RAG pipeline: {str(e)}")
        return None

# --- Page chrome and sidebar controls ---
st.header("Liahona.AI")

sidebar = st.sidebar

# Let the user pick which backing LLM to chat with.
selected_model = sidebar.selectbox("Select Model", list(model_links.keys()))
st.markdown(f"_powered_ by ***:violet[{selected_model}]***")

# Sampling temperature passed through to the LLM.
temperature = sidebar.slider("Select a temperature value", 0.0, 1.0, 0.5)

# Describe the currently selected model in the sidebar.
selected_info = model_info[selected_model]
sidebar.write(f"You're now chatting with **{selected_model}**")
sidebar.markdown(selected_info["description"])
sidebar.image(selected_info["logo"])
sidebar.markdown("*Generated content may be inaccurate or false.*")

# --- Conversation state ---
# Seed the chat history on the first run of the session.
st.session_state.setdefault("messages", [])

# Replay the stored conversation so it survives Streamlit reruns.
for past in st.session_state.messages:
    with st.chat_message(past["role"]):
        st.markdown(past["content"])

# Build (or fetch the cached) retrieval-augmented QA chain.
qa_chain = setup_rag_pipeline("index_training.json", selected_model, temperature)

# --- Chat turn handling ---
if prompt := st.chat_input("Type message here..."):
    # Echo the user's message and record it in the history.
    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Run the RAG chain and render the answer in the assistant bubble.
    with st.chat_message("assistant"):
        try:
            if qa_chain is None:
                raise ValueError("RAG pipeline is not properly set up.")

            answer = qa_chain({"query": prompt})
            response = answer["result"]
            st.write(response)

        except Exception as e:
            # Friendly failure screen: apology text, a random dog photo, and
            # the underlying error for debugging.
            response = """😵‍💫 Looks like someone unplugged something!
            \n Either the model space is being updated or something is down.
            \n"""
            st.write(response)
            st.image(random.choice(random_dogs))
            st.write("This was the error message:")
            st.write(str(e))

    # Persist whichever response was shown so it replays on the next rerun.
    st.session_state.messages.append({"role": "assistant", "content": response})