Update app.py
app.py
CHANGED
@@ -1,152 +1,100 @@
import os
-import
from openai import OpenAI
-import
from dotenv import load_dotenv
-
-from
-
-from langchain_community.document_loaders.hugging_face_dataset import HuggingFaceDatasetLoader
-from langchain_huggingface.embeddings.huggingface_endpoint import HuggingFaceEndpointEmbeddings
-from langchain.chains import RetrievalQA
-from langchain_community.vectorstores import FAISS

# Load environment variables
load_dotenv()

-api_key=os.environ.get('API_KEY')

-

# Constants
MAX_TOKENS = 4000
-DEFAULT_TEMPERATURE = 0.

-# Initialize the OpenAI client
client = OpenAI(
-
-
)
-
# Create supported models
model_links = {
    "Meta-Llama-3.1-8B": "meta-llama/Meta-Llama-3.1-8B-Instruct",
-    "Mistral-7B-Instruct-v0.3": "mistralai/Mistral-7B-Instruct-v0.3",
-    "Gemma-2-27b-it": "google/gemma-2-27b-it",
    "Falcon-7b-Instruct": "tiiuae/falcon-7b-instruct",
}

-# Load documents and set up RAG pipeline
-@st.cache_resource
-def setup_rag_pipeline():
-    loader = HuggingFaceDatasetLoader(
-        path='Atreyu4EVR/General-BYUI-Data',
-        page_content_column='content'
-    )
-    documents = loader.load()

-
-
-        task="feature-extraction",
-        huggingfacehub_api_token=api_key
-    )

-    vector_store = FAISS.from_documents(documents, hf_embeddings)
-    retriever = vector_store.as_retriever()

-    return retriever

def reset_conversation():
-
-
-
def main():
-    st.header('Multi-Models

    # Sidebar for model selection and temperature
    selected_model = st.sidebar.selectbox("Select Model", list(model_links.keys()))
    temperature = st.sidebar.slider('Select a temperature value', 0.0, 1.0, DEFAULT_TEMPERATURE)

-

    if "prev_option" not in st.session_state:
        st.session_state.prev_option = selected_model

    if st.session_state.prev_option != selected_model:
-        st.session_state.
-        st.
        st.session_state.prev_option = selected_model

    st.markdown(f'_powered_ by ***:violet[{selected_model}]***')

-    # Display model info
    st.sidebar.write(f"You're now chatting with **{selected_model}**")
    st.sidebar.markdown("*Generated content may be inaccurate or false.*")

-    # Initialize chat history
-    if "visible_messages" not in st.session_state:
-        st.session_state.visible_messages = []
-    if "full_context" not in st.session_state:
-        st.session_state.full_context = []
-
-    # Display chat messages from history on app rerun
-    for message in st.session_state.visible_messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

-    # Set up RAG pipeline
-    retriever = setup_rag_pipeline()

    # Chat input and response
    if prompt := st.chat_input("Type message here..."):
-        process_user_input(client, prompt, selected_model, temperature

-def process_user_input(client, prompt, selected_model, temperature
    # Display user message
    with st.chat_message("user"):
        st.markdown(prompt)
-    st.session_state.
-
-    # Retrieve relevant documents
-    relevant_docs = retriever.get_relevant_documents(prompt)
-    context = "\n".join([doc.page_content for doc in relevant_docs])
-
-    # Prepare full context with system message and retrieved context
-    full_context = [
-        {"role": "system", "content": f"You are 'Liahona' an AI chatbot for Brigham Young University-Idaho (BYU-I) students, employees, staff and administrators. Your role is to use the retreived content to form the best response possible to the user's question. Be thorough, helpful, and friendly. Here is content that closely matches the question: {context}"},
-        *st.session_state.full_context,
-        {"role": "user", "content": prompt}
-    ]
-
-    # Update full context in session state
-    st.session_state.full_context = full_context

    # Generate and display assistant response
    with st.chat_message("assistant"):
-        try:
-            stream = client.chat.completions.create(
-                model=model_links[selected_model],
-                messages=full_context,
-                temperature=temperature,
-                stream=True,
-                max_tokens=MAX_TOKENS,
-            )
-            response = st.write_stream(stream)
-        except Exception as e:
-            handle_error(e)
-            return
-
-    # Update visible messages and full context
-    st.session_state.visible_messages.append({"role": "assistant", "content": response})
-
-def handle_error(error):
    response = """😵‍💫 Looks like someone unplugged something!
    \n Either the model space is being updated or something is down."""
    st.write(response)
-    random_dog_pick = random.choice(
    st.image(random_dog_pick)
    st.write("This was the error message:")
-    st.write(str(error))
-
-if __name__ == "__main__":
-    main()
+import streamlit as st
import os
+import torch
from openai import OpenAI
+import numpy as np
+import sys
from dotenv import load_dotenv
+import random
+from huggingface_hub import InferenceClient
+

# Load environment variables
load_dotenv()

+
+

# Constants
MAX_TOKENS = 4000
+DEFAULT_TEMPERATURE = 0.5
+
+# initialize the client

client = OpenAI(
+    base_url="https://api-inference.huggingface.co/v1",
+    api_key=os.environ.get('API_KEY')  # Replace with your token
)
+
# Create supported models
model_links = {
    "Meta-Llama-3.1-8B": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "Falcon-7b-Instruct": "tiiuae/falcon-7b-instruct",
}

+# Random dog images for error message
+random_dog_images = ["broken_llama3.jpeg"]


def reset_conversation():
+    '''
+    Resets Conversation
+    '''
+    st.session_state.conversation = []
+    st.session_state.messages = []
+    return None
+
+st.sidebar.button('Reset Chat', on_click=reset_conversation) #Reset button
+
def main():
+    st.header('Multi-Models')
+

    # Sidebar for model selection and temperature
    selected_model = st.sidebar.selectbox("Select Model", list(model_links.keys()))
    temperature = st.sidebar.slider('Select a temperature value', 0.0, 1.0, DEFAULT_TEMPERATURE)

+

    if "prev_option" not in st.session_state:
        st.session_state.prev_option = selected_model

    if st.session_state.prev_option != selected_model:
+        st.session_state.messages = []
+        # st.write(f"Changed to {selected_model}")
        st.session_state.prev_option = selected_model
+        reset_conversation()

    st.markdown(f'_powered_ by ***:violet[{selected_model}]***')

+    # Display model info and logo
    st.sidebar.write(f"You're now chatting with **{selected_model}**")
    st.sidebar.markdown("*Generated content may be inaccurate or false.*")

        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Chat input and response
    if prompt := st.chat_input("Type message here..."):
+        process_user_input(client, prompt, selected_model, temperature)

+def process_user_input(client, prompt, selected_model, temperature):
    # Display user message
    with st.chat_message("user"):
        st.markdown(prompt)
+    st.session_state.messages.append({"role": "user", "content": prompt})
+

    # Generate and display assistant response
    with st.chat_message("assistant"):

        response = """😵‍💫 Looks like someone unplugged something!
        \n Either the model space is being updated or something is down."""
        st.write(response)
+        random_dog_pick = random.choice(random_dog_images)
        st.image(random_dog_pick)
        st.write("This was the error message:")
+        st.write(str(error))
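For reference, the pattern the updated app.py switches to — the stock OpenAI Python client pointed at the Hugging Face Inference API's OpenAI-compatible endpoint, with a streamed chat completion — can be tried outside Streamlit with a minimal sketch like the one below. The base URL, the API_KEY environment variable, and the model id are taken from the diff above; the prompt, temperature, and token limit are placeholder values, not part of the commit.

import os

from dotenv import load_dotenv
from openai import OpenAI

# Illustrative sketch only: client setup mirrors the updated app.py.
load_dotenv()

client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1",
    api_key=os.environ.get("API_KEY"),
)

# Stream a completion from one of the models listed in model_links.
stream = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    temperature=0.5,
    max_tokens=100,
    stream=True,
)

for chunk in stream:
    # Each streamed chunk carries an incremental piece of the assistant reply.
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
print()

Inside the Streamlit app, the previous revision passed the same kind of stream object straight to st.write_stream instead of looping over chunks manually.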