import streamlit as st
from openai import OpenAI
import os
import json
from dotenv import load_dotenv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain_community.llms import HuggingFaceHub
from langchain.chains import RetrievalQA
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from tqdm import tqdm
import random

# Load environment variables
load_dotenv()
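# Note: the HuggingFaceHub LLM used below calls the Hugging Face Inference API,
# so the environment (e.g. the .env file loaded above) is expected to provide an
# API token, typically via HUGGINGFACEHUB_API_TOKEN.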
# Constants
CHUNK_SIZE = 8192
CHUNK_OVERLAP = 200
BATCH_SIZE = 100
RETRIEVER_K = 4
VECTORSTORE_PATH = "./vectorstore"

# Model information
model_links = {
    "Meta-Llama-3.1-8B": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "Mistral-7B-Instruct-v0.3": "mistralai/Mistral-7B-Instruct-v0.3",
}
model_info = {
    "Meta-Llama-3.1-8B": {
        "description": """The Llama (3.1) model is a **Large Language Model (LLM)** that can handle question-and-answer interactions.
        \nIt was created by the [**Meta AI**](https://llama.meta.com/) team and has over **8 billion parameters.**\n""",
        "logo": "llama_logo.gif",
    },
    "Mistral-7B-Instruct-v0.3": {
        "description": """The Mistral-7B-Instruct-v0.3 Large Language Model (LLM) is an instruct fine-tuned version of Mistral-7B-v0.3.
        \nIt was created by the [**Mistral AI**](https://mistral.ai/news/announcing-mistral-7b/) team and has over **7 billion parameters.**\n""",
        "logo": "https://mistral.ai/images/logo_hubc88c4ece131b91c7cb753f40e9e1cc5_2589_256x0_resize_q97_h2_lanczos_3.webp",
    },
}
# Random dog images for error message
random_dogs = ["randomdog.jpg", "randomdog2.jpg", "randomdog3.jpg"]  # Add more as needed

# Set up embeddings
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
def load_and_process_documents(file_path):
    """Load and process documents from a JSON file."""
    try:
        with open(file_path, "r") as file:
            data = json.load(file)
        documents = data.get("documents", [])
        if not documents:
            raise ValueError("No valid documents found in JSON file.")
        doc_objects = [
            Document(
                page_content=doc["content"],
                metadata={"title": doc["title"], "id": doc["id"]},
            )
            for doc in documents
        ]
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
        )
        splits = text_splitter.split_documents(doc_objects)
        return splits
    except Exception as e:
        st.error(f"Error loading documents: {str(e)}")
        return []
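# For illustration, a minimal input file that satisfies the keys read above
# ("documents", "content", "title", "id") could look like this; the field
# values are hypothetical:
# {
#     "documents": [
#         {"id": "doc-1", "title": "Example title", "content": "Full document text..."}
#     ]
# }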
def get_vectorstore(file_path):
    """Get or create a vectorstore."""
    try:
        if os.path.exists(VECTORSTORE_PATH):
            print("Loading existing vectorstore...")
            return Chroma(
                persist_directory=VECTORSTORE_PATH, embedding_function=embeddings
            )
        print("Creating new vectorstore...")
        splits = load_and_process_documents(file_path)
        vectorstore = None
        for i in tqdm(range(0, len(splits), BATCH_SIZE), desc="Processing batches"):
            batch = splits[i : i + BATCH_SIZE]
            if vectorstore is None:
                vectorstore = Chroma.from_documents(
                    documents=batch,
                    embedding=embeddings,
                    persist_directory=VECTORSTORE_PATH,
                )
            else:
                vectorstore.add_documents(documents=batch)
        vectorstore.persist()
        return vectorstore
    except Exception as e:
        st.error(f"Error creating vectorstore: {str(e)}")
        return None
def setup_rag_pipeline(file_path, model_name, temperature):
    """Set up the RAG pipeline."""
    try:
        vectorstore = get_vectorstore(file_path)
        if vectorstore is None:
            raise ValueError("Failed to create or load vectorstore.")
        llm = HuggingFaceHub(
            repo_id=model_links[model_name],
            model_kwargs={"temperature": temperature, "max_length": 4000},
        )
        return RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=vectorstore.as_retriever(search_kwargs={"k": RETRIEVER_K}),
            return_source_documents=True,
        )
    except Exception as e:
        st.error(f"Error setting up RAG pipeline: {str(e)}")
        return None
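# Note: chain_type="stuff" concatenates the RETRIEVER_K retrieved chunks into a
# single prompt for the LLM, and with return_source_documents=True the chain
# returns a dict containing "result" and "source_documents"; only "result" is
# shown in the chat below.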
# Streamlit app
st.header("Liahona.AI")

# Sidebar for model selection
selected_model = st.sidebar.selectbox("Select Model", list(model_links.keys()))
st.markdown(f"_powered_ by ***:violet[{selected_model}]***")

# Temperature slider
temperature = st.sidebar.slider("Select a temperature value", 0.0, 1.0, 0.5)

# Display model info
st.sidebar.write(f"You're now chatting with **{selected_model}**")
st.sidebar.markdown(model_info[selected_model]["description"])
st.sidebar.image(model_info[selected_model]["logo"])
st.sidebar.markdown("*Generated content may be inaccurate or false.*")
# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Set up advanced RAG pipeline
qa_chain = setup_rag_pipeline("index_training.json", selected_model, temperature)
# Chat input
if prompt := st.chat_input("Type message here..."):
    # Display user message
    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Generate and display assistant response
    with st.chat_message("assistant"):
        try:
            if qa_chain is None:
                raise ValueError("RAG pipeline is not properly set up.")
            result = qa_chain({"query": prompt})
            response = result["result"]
            st.write(response)
        except Exception as e:
            response = """😵💫 Looks like someone unplugged something!
            \n Either the model space is being updated or something is down.
            \n"""
            st.write(response)
            random_dog_pick = random.choice(random_dogs)
            st.image(random_dog_pick)
            st.write("This was the error message:")
            st.write(str(e))
    st.session_state.messages.append({"role": "assistant", "content": response})