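"""Streamlit chat app: multi-model chat over the Hugging Face Inference API with a FAISS-backed RAG pipeline for BYU-Idaho data."""
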
import os
import random

import streamlit as st
from dotenv import load_dotenv
from huggingface_hub import get_token
from langchain_community.document_loaders.hugging_face_dataset import HuggingFaceDatasetLoader
from langchain_community.vectorstores import FAISS
from langchain_huggingface.embeddings.huggingface_endpoint import HuggingFaceEndpointEmbeddings
from openai import OpenAI

# Load environment variables; fall back to the locally cached Hugging Face token
# when API_KEY is not set
load_dotenv()
api_key = os.environ.get("API_KEY") or get_token()
# Constants
MAX_TOKENS = 4000
DEFAULT_TEMPERATURE = 0.75

# Initialize the OpenAI client
client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1",
    api_key=api_key,
)

# Supported models
model_links = {
    "Meta-Llama-3.1-8B": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "Mistral-7B-Instruct-v0.3": "mistralai/Mistral-7B-Instruct-v0.3",
    "Gemma-2-27b-it": "google/gemma-2-27b-it",
    "Falcon-7b-Instruct": "tiiuae/falcon-7b-instruct",
}


# Load documents and set up the RAG pipeline (cached so the FAISS index is
# built once per session rather than on every Streamlit rerun)
@st.cache_resource
def setup_rag_pipeline():
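    """Build a FAISS retriever over the BYU-I dataset using Hugging Face endpoint embeddings."""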
    loader = HuggingFaceDatasetLoader(
        path="Atreyu4EVR/General-BYUI-Data",
        page_content_column="content",
    )
    documents = loader.load()

    hf_embeddings = HuggingFaceEndpointEmbeddings(
        model="sentence-transformers/all-MiniLM-L12-v2",
        task="feature-extraction",
        huggingfacehub_api_token=api_key,
    )

    vector_store = FAISS.from_documents(documents, hf_embeddings)
    retriever = vector_store.as_retriever()
    return retriever


def reset_conversation():
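    """Clear both the visible chat history and the model-facing context."""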
    st.session_state.visible_messages = []
    st.session_state.full_context = []


def main():
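    """Render the sidebar and chat history, then dispatch new prompts to the model."""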
    st.header("Multi-Models with RAG")

    # Sidebar for model selection and temperature
    selected_model = st.sidebar.selectbox("Select Model", list(model_links.keys()))
    temperature = st.sidebar.slider("Select a temperature value", 0.0, 1.0, DEFAULT_TEMPERATURE)
    st.sidebar.button("Reset Chat", on_click=reset_conversation)

    # Reset the conversation when the selected model changes
    if "prev_option" not in st.session_state:
        st.session_state.prev_option = selected_model
    if st.session_state.prev_option != selected_model:
        st.session_state.visible_messages = []
        st.session_state.full_context = []
        st.session_state.prev_option = selected_model

    st.markdown(f"_powered_ by ***:violet[{selected_model}]***")

    # Display model info
    st.sidebar.write(f"You're now chatting with **{selected_model}**")
    st.sidebar.markdown("*Generated content may be inaccurate or false.*")

    # Initialize chat history
    if "visible_messages" not in st.session_state:
        st.session_state.visible_messages = []
    if "full_context" not in st.session_state:
        st.session_state.full_context = []

    # Display chat messages from history on app rerun
    for message in st.session_state.visible_messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Set up the RAG pipeline
    retriever = setup_rag_pipeline()

    # Chat input and response
    if prompt := st.chat_input("Type message here..."):
        process_user_input(client, prompt, selected_model, temperature, retriever)


def process_user_input(client, prompt, selected_model, temperature, retriever):
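    """Retrieve context for the prompt, build the model message list, and stream the reply."""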
    # Display user message
    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state.visible_messages.append({"role": "user", "content": prompt})

    # Retrieve relevant documents
    relevant_docs = retriever.get_relevant_documents(prompt)
    context = "\n".join([doc.page_content for doc in relevant_docs])

    # Build the message list for this turn: a system message carrying the retrieved
    # context, followed by the prior conversation and the new user prompt
    full_context = [
        {"role": "system", "content": f"You are 'Liahona', an AI chatbot for Brigham Young University-Idaho (BYU-I) students, employees, staff, and administrators. Your role is to use the retrieved content to form the best possible response to the user's question. Be thorough, helpful, and friendly. Here is content that closely matches the question: {context}"},
        *st.session_state.full_context,
        {"role": "user", "content": prompt},
    ]

    # Track the user turn (without the per-turn system message) for future requests
    st.session_state.full_context.append({"role": "user", "content": prompt})
    # Generate and display assistant response
    with st.chat_message("assistant"):
        try:
            stream = client.chat.completions.create(
                model=model_links[selected_model],
                messages=full_context,
                temperature=temperature,
                stream=True,
                max_tokens=MAX_TOKENS,
            )
            response = st.write_stream(stream)
        except Exception as e:
            handle_error(e)
            return

    # Update visible messages and the model-facing context with the assistant reply
    st.session_state.visible_messages.append({"role": "assistant", "content": response})
    st.session_state.full_context.append({"role": "assistant", "content": response})


def handle_error(error):
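    """Show a friendly failure message, a fallback image, and the underlying error text."""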
response = """π΅βπ« Looks like someone unplugged something! | |
\n Either the model space is being updated or something is down.""" | |
st.write(response) | |
random_dog_pick = random.choice(["broken_llama3.jpeg"]) | |
st.image(random_dog_pick) | |
st.write("This was the error message:") | |
st.write(str(error)) | |


if __name__ == "__main__":
    main()