Atreyu4EVR committed
Commit 9f80d5d · verified · 1 Parent(s): e747f55

Update app.py

Files changed (1)
  1. app.py +95 -50
app.py CHANGED
@@ -1,17 +1,24 @@
import streamlit as st
from openai import OpenAI
- import torch
import os
- import sys
- from dotenv import load_dotenv, dotenv_values
import numpy as np

load_dotenv()

- # Initialize the client
client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1",
-     api_key=os.environ.get('HUGGINGFACEHUB_API_TOKEN') # Replace with your token
)

# Create supported models
@@ -21,7 +28,6 @@ model_links = {
    "gemma-2-2b": "google/gemma-2-2b",
}

- # Pull info about the model to display
model_info = {
    "Meta-Llama-3.1-8B": {
        'description': """The Llama (3.1) model is a **Large Language Model (LLM)** that's able to have question and answer interactions.
@@ -49,78 +55,117 @@ models = [key for key in model_links.keys()]
# Create the sidebar with the dropdown for model selection
selected_model = st.sidebar.selectbox("Select Model", models)

- # Create a temperature slider
temp_values = st.sidebar.slider('Select a temperature value', 0.0, 1.0, 0.5)

- # Create model description
st.sidebar.write(f"You're now chatting with **{selected_model}**")
st.sidebar.markdown(model_info[selected_model]['description'])
st.sidebar.image(model_info[selected_model]['logo'])
st.sidebar.markdown("*Generated content may be inaccurate or false.*")

- if "prev_option" not in st.session_state:
-     st.session_state.prev_option = selected_model
-
- if st.session_state.prev_option != selected_model:
-     st.session_state.messages = []
-     st.session_state.prev_option = selected_model
-
- # Pull in the model we want to use
- repo_id = model_links[selected_model]
-
- st.header('Liahona.AI')
- st.markdown(f'_powered_ by ***:violet[{selected_model}]***')
-
- # Set a default model
- if selected_model not in st.session_state:
-     st.session_state[selected_model] = model_links[selected_model]
-
# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

- # Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

- # Accept user input
- if prompt := st.chat_input("Type message here..."):

-     # Display user message in chat message container
    with st.chat_message("user"):
        st.markdown(prompt)
-     # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})

-     # Display assistant response in chat message container
    with st.chat_message("assistant"):
-
        try:
-             stream = client.chat.completions.create(
-                 model=repo_id,
-                 messages=[
-                     {"role": m["role"], "content": m["content"]}
-                     for m in st.session_state.messages
-                 ],
-                 temperature=temp_values,
-                 stream=True,
-                 max_tokens=4000,
-             )
-
-             response = st.write_stream(stream)
-
        except Exception as e:
            response = """😵‍💫 Looks like someone unplugged something!
            \n Either the model space is being updated or something is down.
-             \n
-             \n Try again later.
-             \n
-             \n Here's a random pic of a 🐶:"""
            st.write(response)
            random_dog_pick = 'https://random.dog/' + random_dog[np.random.randint(len(random_dog))]
            st.image(random_dog_pick)
            st.write("This was the error message:")
-             st.write(e)

    st.session_state.messages.append({"role": "assistant", "content": response})
 
import streamlit as st
from openai import OpenAI
import os
+ from dotenv import load_dotenv
import numpy as np
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.schema import Document
+ from langchain_community.llms import HuggingFaceHub
+ from langchain.chains import RetrievalQA, LLMChain
+ from langchain.prompts import PromptTemplate
+ from langchain_community.embeddings import HuggingFaceInstructEmbeddings
+ from langchain_community.vectorstores import Chroma
+ from langchain.retrievers import ContextualCompressionRetriever
+ from langchain.retrievers.document_compressors import LLMChainExtractor

load_dotenv()

+ # Initialize the OpenAI client for Hugging Face
client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1",
+     api_key=os.environ.get('HUGGINGFACEHUB_API_TOKEN')
)

# Create supported models
 
    "gemma-2-2b": "google/gemma-2-2b",
}

model_info = {
    "Meta-Llama-3.1-8B": {
        'description': """The Llama (3.1) model is a **Large Language Model (LLM)** that's able to have question and answer interactions.
 
# Create the sidebar with the dropdown for model selection
selected_model = st.sidebar.selectbox("Select Model", models)

+ # Function to load and process documents
+ def load_and_process_documents(file_path):
+     with open(file_path, 'r') as file:
+         content = file.read()
+
+     doc = Document(page_content=content, metadata={"source": file_path})
+
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=8192, chunk_overlap=200)
+     splits = text_splitter.split_documents([doc])
+
+     return splits
+
+ # Function to set up the advanced RAG pipeline
+ @st.cache_resource
+ def setup_advanced_rag_pipeline(model_name):
+     # Load and process documents
+     splits = load_and_process_documents("index_training.json") # Replace with your document path
+
+     # Set up InstructorEmbeddings
+     embeddings = HuggingFaceInstructEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+
+     # Create vectorstore
+     vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
+
+     # Set up language model
+     llm = HuggingFaceHub(repo_id=model_links[model_name], model_kwargs={"temperature": 0.5, "max_length": 4000})
+
+     # Set up HyDE
+     hyde_prompt = PromptTemplate(
+         input_variables=["question"],
+         template="Please write a passage to answer the question\nQuestion: {question}\nPassage:"
+     )
+     hyde_chain = LLMChain(llm=llm, prompt=hyde_prompt)
+
+     def hyde_retriever(query):
+         hypothetical_doc = hyde_chain.run(query)
+         hyde_embedding = embeddings.embed_query(hypothetical_doc)
+         return vectorstore.similarity_search_by_vector(hyde_embedding, k=3)
+
+     # Set up ContextualCompressionRetriever
+     compressor = LLMChainExtractor.from_llm(llm)
+     compression_retriever = ContextualCompressionRetriever(
+         base_compressor=compressor,
+         base_retriever=hyde_retriever
+     )
+
+     # Create RetrievalQA chain
+     qa_chain = RetrievalQA.from_chain_type(
+         llm=llm,
+         chain_type="stuff",
+         retriever=compression_retriever,
+         return_source_documents=True
+     )
+
+     return qa_chain
+
+ # Streamlit app
+ st.header('Liahona.AI')
+
+ # Sidebar for model selection
+ selected_model = st.sidebar.selectbox("Select Model", list(model_links.keys()))
+ st.markdown(f'_powered_ by ***:violet[{selected_model}]***')
+
+ # Temperature slider
temp_values = st.sidebar.slider('Select a temperature value', 0.0, 1.0, 0.5)

+ # Display model info
st.sidebar.write(f"You're now chatting with **{selected_model}**")
st.sidebar.markdown(model_info[selected_model]['description'])
st.sidebar.image(model_info[selected_model]['logo'])
st.sidebar.markdown("*Generated content may be inaccurate or false.*")

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

+ # Display chat messages from history
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

+ # Set up advanced RAG pipeline
+ qa_chain = setup_advanced_rag_pipeline(selected_model)

+ # Chat input
+ if prompt := st.chat_input("Type message here..."):
+     # Display user message
    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

+     # Generate and display assistant response
    with st.chat_message("assistant"):
        try:
+             result = qa_chain({"query": prompt})
+             response = result["result"]
+             st.write(response)
+
+             # Optionally, display source documents
+             with st.expander("View Source Documents"):
+                 for doc in result["source_documents"]:
+                     st.write(doc.page_content)
+                     st.write("---")
        except Exception as e:
            response = """😵‍💫 Looks like someone unplugged something!
            \n Either the model space is being updated or something is down.
+             \n"""
            st.write(response)
            random_dog_pick = 'https://random.dog/' + random_dog[np.random.randint(len(random_dog))]
            st.image(random_dog_pick)
            st.write("This was the error message:")
+             st.write(str(e))

    st.session_state.messages.append({"role": "assistant", "content": response})
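
One caveat on the new retrieval wiring: setup_advanced_rag_pipeline passes the plain hyde_retriever function as base_retriever, but in the LangChain versions these imports come from, ContextualCompressionRetriever is a pydantic model whose base_retriever field expects a retriever object, so construction may fail with a validation error. Below is a minimal sketch of one way to wrap the HyDE lookup as a retriever; the HyDERetriever class name and its fields are illustrative rather than part of the commit, and the exact override hook differs across LangChain versions.

from typing import Any, List

from langchain.schema import BaseRetriever, Document

class HyDERetriever(BaseRetriever):
    """Draft a hypothetical answer to the query, embed it, and search the vectorstore with that embedding."""
    hyde_chain: Any    # e.g. the LLMChain built from hyde_prompt
    embeddings: Any    # the embedding model used for the vectorstore
    vectorstore: Any   # e.g. the Chroma store built in setup_advanced_rag_pipeline
    k: int = 3         # number of documents to return

    def _get_relevant_documents(self, query: str) -> List[Document]:
        hypothetical_doc = self.hyde_chain.run(query)
        hyde_embedding = self.embeddings.embed_query(hypothetical_doc)
        return self.vectorstore.similarity_search_by_vector(hyde_embedding, k=self.k)

The pipeline could then construct base_retriever=HyDERetriever(hyde_chain=hyde_chain, embeddings=embeddings, vectorstore=vectorstore) in place of the bare function, keeping the rest of the compression and RetrievalQA setup unchanged.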