import os

import streamlit as st
from dotenv import load_dotenv
from huggingface_hub import get_token
from langchain_community.document_loaders.hugging_face_dataset import HuggingFaceDatasetLoader
from langchain_community.vectorstores import FAISS
from langchain_huggingface.embeddings.huggingface_endpoint import HuggingFaceEndpointEmbeddings
from openai import OpenAI

# Load environment variables; fall back to the locally cached Hugging Face token
load_dotenv()
api_key = os.environ.get("API_KEY") or get_token()

# Constants
MAX_TOKENS = 4000
DEFAULT_TEMPERATURE = 0.5

# Initialize the OpenAI-compatible client against the Hugging Face Inference API
client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1",
    api_key=api_key,
)

# Supported models: display name -> Hugging Face model ID
model_links = {
    "Meta-Llama-3.1-8B": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "Mistral-7B-Instruct-v0.3": "mistralai/Mistral-7B-Instruct-v0.3",
    "Gemma-2-27b-it": "google/gemma-2-27b-it",
    "Falcon-7b-Instruct": "tiiuae/falcon-7b-instruct",
}


# Load documents and set up the RAG pipeline (cached across Streamlit reruns)
@st.cache_resource
def setup_rag_pipeline():
    loader = HuggingFaceDatasetLoader(
        path="Atreyu4EVR/General-BYUI-Data",
        page_content_column="content",
    )
    documents = loader.load()
    hf_embeddings = HuggingFaceEndpointEmbeddings(
        model="sentence-transformers/all-MiniLM-L12-v2",
        task="feature-extraction",
        huggingfacehub_api_token=api_key,
    )
    vector_store = FAISS.from_documents(documents, hf_embeddings)
    return vector_store.as_retriever()


def reset_conversation():
    st.session_state.messages = []


def main():
    st.header("Multi-Models with RAG")

    # Sidebar for model selection and temperature
    selected_model = st.sidebar.selectbox("Select Model", list(model_links.keys()))
    temperature = st.sidebar.slider("Select a temperature value", 0.0, 1.0, DEFAULT_TEMPERATURE)
    st.sidebar.button("Reset Chat", on_click=reset_conversation)

    # Clear the conversation when the user switches models
    if "prev_option" not in st.session_state:
        st.session_state.prev_option = selected_model
    if st.session_state.prev_option != selected_model:
        st.session_state.prev_option = selected_model
        reset_conversation()

    st.markdown(f"_powered_ by ***:violet[{selected_model}]***")

    # Display model info
    st.sidebar.write(f"You're now chatting with **{selected_model}**")
    st.sidebar.markdown("*Generated content may be inaccurate or false.*")

    # Initialize chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Set up RAG pipeline
    retriever = setup_rag_pipeline()

    # Chat input and response
    if prompt := st.chat_input("Type message here..."):
        process_user_input(client, prompt, selected_model, temperature, retriever)


def process_user_input(client, prompt, selected_model, temperature, retriever):
    # Display the user message and persist it to the chat history
    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Retrieve relevant documents and join them into a single context string
    relevant_docs = retriever.invoke(prompt)
    context = "\n".join(doc.page_content for doc in relevant_docs)

    # Build the context-bearing system message fresh on every turn, and keep it
    # out of the stored history so stale context does not accumulate there
    system_message = {
        "role": "system",
        "content": (
            "You are an AI assistant. "
            f"Use the following context to answer the user's question: {context}"
        ),
    }

    # Generate and display the assistant response
    with st.chat_message("assistant"):
        try:
            stream = client.chat.completions.create(
                model=model_links[selected_model],
                messages=[system_message, *st.session_state.messages],
                temperature=temperature,
                stream=True,
                max_tokens=MAX_TOKENS,
            )
            response = st.write_stream(stream)
        except Exception as e:
            handle_error(e)
            return

    st.session_state.messages.append({"role": "assistant", "content": response})


def handle_error(error):
    st.write(
        "😵‍💫 Looks like someone unplugged something!\n\n"
        "Either the model space is being updated or something is down."
    )
    st.image("broken_llama3.jpeg")  # local placeholder image
    st.write("This was the error message:")
    st.write(str(error))


if __name__ == "__main__":
    main()
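
# ---------------------------------------------------------------------------
# Usage sketch (not executed by the app): a hedged example of how this script
# is typically run. The filename `app.py`, the `.env` layout, and the exact
# package list are assumptions; the `API_KEY` variable and the model/dataset
# IDs come from the script above.
#
#   pip install streamlit openai python-dotenv huggingface_hub \
#       langchain-community langchain-huggingface faiss-cpu
#   echo "API_KEY=<your Hugging Face access token>" > .env
#   streamlit run app.py
# ---------------------------------------------------------------------------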