import os
import random
from openai import OpenAI
import streamlit as st
from dotenv import load_dotenv
from huggingface_hub import get_token
from langchain_community.document_loaders.hugging_face_dataset import HuggingFaceDatasetLoader
from langchain_huggingface.embeddings.huggingface_endpoint import HuggingFaceEndpointEmbeddings
from langchain_community.vectorstores import FAISS
# Load environment variables; fall back to the locally cached Hugging Face token if API_KEY is unset
load_dotenv()
api_key = os.environ.get('API_KEY') or get_token()
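# Assumed .env layout next to app.py (placeholder value shown, not a real token):
#   API_KEY=hf_xxxxxxxxxxxxxxxx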
# Constants
MAX_TOKENS = 4000
DEFAULT_TEMPERATURE = 0.75
# Initialize an OpenAI-compatible client pointed at the Hugging Face Inference API
client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1",
    api_key=api_key,
)
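# Optional sanity check: the same client can be exercised outside Streamlit, e.g.
#   client.chat.completions.create(
#       model="meta-llama/Meta-Llama-3.1-8B-Instruct",
#       messages=[{"role": "user", "content": "Hello"}],
#       max_tokens=16,
#   )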
# Supported models: display name -> Hugging Face model repo id
model_links = {
    "Meta-Llama-3.1-8B": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "Mistral-7B-Instruct-v0.3": "mistralai/Mistral-7B-Instruct-v0.3",
    "Gemma-2-27b-it": "google/gemma-2-27b-it",
    "Falcon-7b-Instruct": "tiiuae/falcon-7b-instruct",
}
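# To expose another hosted model, add its display name and repo id, for example (hypothetical entry):
#   model_links["Phi-3-mini-4k-instruct"] = "microsoft/Phi-3-mini-4k-instruct"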
# Load documents and set up the RAG pipeline (cached so the index is built once per session)
@st.cache_resource
def setup_rag_pipeline():
    # Pull the BYU-I dataset from the Hugging Face Hub
    loader = HuggingFaceDatasetLoader(
        path='Atreyu4EVR/General-BYUI-Data',
        page_content_column='content'
    )
    documents = loader.load()

    # Embed the documents with a hosted sentence-transformers endpoint
    hf_embeddings = HuggingFaceEndpointEmbeddings(
        model="sentence-transformers/all-MiniLM-L12-v2",
        task="feature-extraction",
        huggingfacehub_api_token=api_key
    )

    # Index the embeddings in FAISS and expose a retriever over the store
    vector_store = FAISS.from_documents(documents, hf_embeddings)
    retriever = vector_store.as_retriever()
    return retriever
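# Note: as_retriever() uses similarity search with the library's default number of results;
# passing search_kwargs={"k": 4} (or another value) to as_retriever() would control how many
# chunks are retrieved per question. The default is left in place here.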

def reset_conversation():
    # Clear both the rendered chat history and the model-facing context
    st.session_state.visible_messages = []
    st.session_state.full_context = []

def main():
    st.header('Multi-Models with RAG')

    # Sidebar for model selection and temperature
    selected_model = st.sidebar.selectbox("Select Model", list(model_links.keys()))
    temperature = st.sidebar.slider('Select a temperature value', 0.0, 1.0, DEFAULT_TEMPERATURE)
    st.sidebar.button('Reset Chat', on_click=reset_conversation)

    # Reset the conversation when the user switches models
    if "prev_option" not in st.session_state:
        st.session_state.prev_option = selected_model
    if st.session_state.prev_option != selected_model:
        st.session_state.visible_messages = []
        st.session_state.full_context = []
        st.session_state.prev_option = selected_model

    st.markdown(f'_powered_ by ***:violet[{selected_model}]***')

    # Display model info
    st.sidebar.write(f"You're now chatting with **{selected_model}**")
    st.sidebar.markdown("*Generated content may be inaccurate or false.*")

    # Initialize chat history
    if "visible_messages" not in st.session_state:
        st.session_state.visible_messages = []
    if "full_context" not in st.session_state:
        st.session_state.full_context = []

    # Display chat messages from history on app rerun
    for message in st.session_state.visible_messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Set up RAG pipeline
    retriever = setup_rag_pipeline()

    # Chat input and response
    if prompt := st.chat_input("Type message here..."):
        process_user_input(client, prompt, selected_model, temperature, retriever)

def process_user_input(client, prompt, selected_model, temperature, retriever):
    # Display and record the user message
    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state.visible_messages.append({"role": "user", "content": prompt})

    # Retrieve relevant documents (retriever.invoke(prompt) is the newer equivalent on recent LangChain versions)
    relevant_docs = retriever.get_relevant_documents(prompt)
    context = "\n".join([doc.page_content for doc in relevant_docs])

    # Prepare the full message list: a fresh system prompt carrying the retrieved
    # context, followed by the prior turns and the new user message
    system_message = {
        "role": "system",
        "content": (
            "You are 'Liahona', an AI chatbot for Brigham Young University-Idaho (BYU-I) "
            "students, employees, staff, and administrators. Your role is to use the "
            "retrieved content to form the best response possible to the user's question. "
            "Be thorough, helpful, and friendly. Here is content that closely matches the "
            f"question: {context}"
        ),
    }
    full_context = [
        system_message,
        *st.session_state.full_context,
        {"role": "user", "content": prompt},
    ]

    # Store only the conversation turns in session state; the system prompt is rebuilt
    # with freshly retrieved context on every question instead of accumulating
    st.session_state.full_context.append({"role": "user", "content": prompt})

    # Generate and display the assistant response
    with st.chat_message("assistant"):
        try:
            stream = client.chat.completions.create(
                model=model_links[selected_model],
                messages=full_context,
                temperature=temperature,
                stream=True,
                max_tokens=MAX_TOKENS,
            )
            response = st.write_stream(stream)
        except Exception as e:
            handle_error(e)
            return

    # Record the assistant reply in both the visible history and the model context
    st.session_state.visible_messages.append({"role": "assistant", "content": response})
    st.session_state.full_context.append({"role": "assistant", "content": response})

def handle_error(error):
    response = (
        "😵‍💫 Looks like someone unplugged something!\n\n"
        "Either the model space is being updated or something is down."
    )
    st.write(response)
    # Only one fallback image ships with the app; random.choice keeps it easy to add more later
    fallback_image = random.choice(["broken_llama3.jpeg"])
    st.image(fallback_image)
    st.write("This was the error message:")
    st.write(str(error))

if __name__ == "__main__":
    main()
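# Run locally with: streamlit run app.py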