Atreyu4EVR committed on
Commit 8b05694 · verified · 1 Parent(s): ef211f7

Update app.py

Files changed (1)
app.py +39 -91
app.py CHANGED
@@ -1,152 +1,100 @@
+import streamlit as st
 import os
-import random
+import torch
 from openai import OpenAI
-import streamlit as st
+import numpy as np
+import sys
 from dotenv import load_dotenv
-from huggingface_hub import get_token
-from langchain_huggingface import HuggingFaceEndpoint
-from langchain.indexes import VectorstoreIndexCreator
-from langchain_community.document_loaders.hugging_face_dataset import HuggingFaceDatasetLoader
-from langchain_huggingface.embeddings.huggingface_endpoint import HuggingFaceEndpointEmbeddings
-from langchain.chains import RetrievalQA
-from langchain_community.vectorstores import FAISS
+import random
+from huggingface_hub import InferenceClient
+
 
 # Load environment variables
 load_dotenv()
 
-api_key=os.environ.get('API_KEY')
 
-get_token()
+
+
 
 # Constants
 MAX_TOKENS = 4000
-DEFAULT_TEMPERATURE = 0.75
+DEFAULT_TEMPERATURE = 0.5
+
+# initialize the client
 
-# Initialize the OpenAI client
 client = OpenAI(
-    base_url="https://api-inference.huggingface.co/v1",
-    api_key=api_key
+    base_url="https://api-inference.huggingface.co/v1",
+    api_key=os.environ.get('API_KEY') # Replace with your token
 )
-
+
 # Create supported models
 model_links = {
     "Meta-Llama-3.1-8B": "meta-llama/Meta-Llama-3.1-8B-Instruct",
-    "Mistral-7B-Instruct-v0.3": "mistralai/Mistral-7B-Instruct-v0.3",
-    "Gemma-2-27b-it": "google/gemma-2-27b-it",
     "Falcon-7b-Instruct": "tiiuae/falcon-7b-instruct",
 }
 
-# Load documents and set up RAG pipeline
-@st.cache_resource
-def setup_rag_pipeline():
-    loader = HuggingFaceDatasetLoader(
-        path='Atreyu4EVR/General-BYUI-Data',
-        page_content_column='content'
-    )
-    documents = loader.load()
 
-    hf_embeddings = HuggingFaceEndpointEmbeddings(
-        model="sentence-transformers/all-MiniLM-L12-v2",
-        task="feature-extraction",
-        huggingfacehub_api_token=api_key
-    )
+# Random dog images for error message
+random_dog_images = ["broken_llama3.jpeg"]
 
-    vector_store = FAISS.from_documents(documents, hf_embeddings)
-    retriever = vector_store.as_retriever()
 
-    return retriever
 
 def reset_conversation():
-    st.session_state.visible_messages = []
-    st.session_state.full_context = []
-
+    '''
+    Resets Conversation
+    '''
+    st.session_state.conversation = []
+    st.session_state.messages = []
+    return None
+
+st.sidebar.button('Reset Chat', on_click=reset_conversation) #Reset button
+
 def main():
-    st.header('Multi-Models with RAG')
+    st.header('Multi-Models')
+
 
     # Sidebar for model selection and temperature
     selected_model = st.sidebar.selectbox("Select Model", list(model_links.keys()))
     temperature = st.sidebar.slider('Select a temperature value', 0.0, 1.0, DEFAULT_TEMPERATURE)
 
-    st.sidebar.button('Reset Chat', on_click=reset_conversation)
+
 
     if "prev_option" not in st.session_state:
         st.session_state.prev_option = selected_model
 
     if st.session_state.prev_option != selected_model:
-        st.session_state.visible_messages = []
-        st.session_state.full_context = []
+        st.session_state.messages = []
+        # st.write(f"Changed to {selected_model}")
         st.session_state.prev_option = selected_model
+        reset_conversation()
 
     st.markdown(f'_powered_ by ***:violet[{selected_model}]***')
 
-    # Display model info
+    # Display model info and logo
     st.sidebar.write(f"You're now chatting with **{selected_model}**")
     st.sidebar.markdown("*Generated content may be inaccurate or false.*")
 
-    # Initialize chat history
-    if "visible_messages" not in st.session_state:
-        st.session_state.visible_messages = []
-    if "full_context" not in st.session_state:
-        st.session_state.full_context = []
-
-    # Display chat messages from history on app rerun
-    for message in st.session_state.visible_messages:
         with st.chat_message(message["role"]):
             st.markdown(message["content"])
 
-    # Set up RAG pipeline
-    retriever = setup_rag_pipeline()
 
    # Chat input and response
     if prompt := st.chat_input("Type message here..."):
-        process_user_input(client, prompt, selected_model, temperature, retriever)
+        process_user_input(client, prompt, selected_model, temperature)
 
-def process_user_input(client, prompt, selected_model, temperature, retriever):
+def process_user_input(client, prompt, selected_model, temperature):
     # Display user message
     with st.chat_message("user"):
         st.markdown(prompt)
-    st.session_state.visible_messages.append({"role": "user", "content": prompt})
-
-    # Retrieve relevant documents
-    relevant_docs = retriever.get_relevant_documents(prompt)
-    context = "\n".join([doc.page_content for doc in relevant_docs])
-
-    # Prepare full context with system message and retrieved context
-    full_context = [
-        {"role": "system", "content": f"You are 'Liahona' an AI chatbot for Brigham Young University-Idaho (BYU-I) students, employees, staff and administrators. Your role is to use the retreived content to form the best response possible to the user's question. Be thorough, helpful, and friendly. Here is content that closely matches the question: {context}"},
-        *st.session_state.full_context,
-        {"role": "user", "content": prompt}
-    ]
-
-    # Update full context in session state
-    st.session_state.full_context = full_context
+    st.session_state.messages.append({"role": "user", "content": prompt})
+
 
     # Generate and display assistant response
     with st.chat_message("assistant"):
-        try:
-            stream = client.chat.completions.create(
-                model=model_links[selected_model],
-                messages=full_context,
-                temperature=temperature,
-                stream=True,
-                max_tokens=MAX_TOKENS,
-            )
-            response = st.write_stream(stream)
-        except Exception as e:
-            handle_error(e)
-            return
-
-    # Update visible messages and full context
-    st.session_state.visible_messages.append({"role": "assistant", "content": response})
-
-def handle_error(error):
         response = """😵‍💫 Looks like someone unplugged something!
         \n Either the model space is being updated or something is down."""
         st.write(response)
-    random_dog_pick = random.choice(["broken_llama3.jpeg"])
+        random_dog_pick = random.choice(random_dog_images)
         st.image(random_dog_pick)
         st.write("This was the error message:")
-    st.write(str(error))
-
-if __name__ == "__main__":
-    main()
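
As committed, the new file leaves two gaps that the removed lines used to fill: main() keeps the two chat-replay lines without the for-loop header above them, and process_user_input() keeps the error-display lines without the try/except and streaming call that previously defined response and error. A minimal sketch of how those spots presumably fit together, reusing names this commit introduces (st.session_state.messages, random_dog_images) and the call pattern from the removed code; this is a hypothetical reconstruction, not part of the commit:

# Sketch only (not in the commit). In main(), replay stored history on rerun:
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Sketch only. In process_user_input(), stream a completion and fall back
# to the dog-image error display that the commit keeps:
with st.chat_message("assistant"):
    try:
        stream = client.chat.completions.create(
            model=model_links[selected_model],
            messages=st.session_state.messages,
            temperature=temperature,
            stream=True,
            max_tokens=MAX_TOKENS,
        )
        response = st.write_stream(stream)
        st.session_state.messages.append({"role": "assistant", "content": response})
    except Exception as error:
        st.write("""😵‍💫 Looks like someone unplugged something!
        \n Either the model space is being updated or something is down.""")
        st.image(random.choice(random_dog_images))
        st.write("This was the error message:")
        st.write(str(error))

A first run would also need the history list to exist before the append in process_user_input(), e.g. st.session_state.setdefault("messages", []) near the top of main(), since the commit only creates that list inside reset_conversation().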