StevenChen16 committed
Commit 5872c96
1 Parent(s): 1df3e06

Fix bug and remove the duplicate initialization of vector_store

Files changed (1)
  1. app.py +22 -29
app.py CHANGED
@@ -94,7 +94,7 @@ try:
 
 except Exception as e:
     raise RuntimeError(f"Failed to load vector store from HuggingFace Hub: {str(e)}")
-# vector_store = FAISS.load_local(repo_path, embedding_model, allow_dangerous_deserialization=True)
+vector_store = FAISS.load_local(repo_path, embedding_model, allow_dangerous_deserialization=True)
 
 
 background_prompt = '''
@@ -155,19 +155,14 @@ Now, please guide me step by step to describe the legal issues I am facing, acco
 
 def query_vector_store(vector_store: FAISS, query, k=4, relevance_threshold=0.8):
     """
-    Query similar documents from the vector store.
-    Args:
-        vector_store (FAISS): the vector store instance
-        query (str): the query text
-        k (int): the number of documents to return
-        relevance_threshold (float): the relevance threshold
-    Returns:
-        context (list): the retrieved context
+    Query similar documents from vector store.
     """
-    retriever = vector_store.as_retriever(search_type="similarity_score_threshold", search_kwargs={"score_threshold": relevance_threshold, "k": k})
+    retriever = vector_store.as_retriever(search_type="similarity_score_threshold",
+                                          search_kwargs={"score_threshold": relevance_threshold, "k": k})
     similar_docs = retriever.invoke(query)
     context = [doc.page_content for doc in similar_docs]
-    return context
+    # Join the context list into a single string
+    return " ".join(context) if context else ""
 
 @spaces.GPU(duration=120)
 def chat_llama3_8b(message: str,
@@ -177,40 +172,39 @@ def chat_llama3_8b(message: str,
                    ) -> str:
     """
     Generate a streaming response using the llama3-8b model.
-    Args:
-        message (str): The input message.
-        history (list): The conversation history used by ChatInterface.
-        temperature (float): The temperature for generating the response.
-        max_new_tokens (int): The maximum number of new tokens to generate.
-    Returns:
-        str: The generated response.
     """
+    # Get citations from vector store
     citation = query_vector_store(vector_store, message, 4, 0.7)
-    if citation != None:
-        context = "Based on this citations: " + citation + "please answer questions:"
+
+    # Build conversation history
     conversation = []
     for user, assistant in history:
-        # content = background_prompt + user
-        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-    if citation != None:
-        message = background_prompt + context + message
+        conversation.extend([
+            {"role": "user", "content": user},
+            {"role": "assistant", "content": assistant}
+        ])
+
+    # Construct the final message with background prompt and citations
+    if citation:
+        message = f"{background_prompt}Based on these citations: {citation}\nPlease answer question: {message}"
     else:
-        message = background_prompt + message
+        message = f"{background_prompt}{message}"
+
     conversation.append({"role": "user", "content": message})
 
+    # Generate response
     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
-
     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
 
     generate_kwargs = dict(
-        input_ids= input_ids,
+        input_ids=input_ids,
         streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
         temperature=temperature,
         eos_token_id=terminators,
     )
-    # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
+
     if temperature == 0:
         generate_kwargs['do_sample'] = False
 
@@ -220,7 +214,6 @@ def chat_llama3_8b(message: str,
     outputs = []
     for text in streamer:
         outputs.append(text)
-        #print(outputs)
     yield "".join(outputs)
 
 
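The first hunk re-enables the FAISS.load_local call after the try/except, so repo_path and embedding_model must already be defined earlier in app.py. A minimal sketch of what that setup might look like, assuming a Hub dataset repo and an embedding model that do not appear in this diff:

# Hypothetical setup for the load_local call in hunk 1; the repo id and
# embedding model below are placeholders, not values from this commit.
from huggingface_hub import snapshot_download
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

embedding_model = HuggingFaceEmbeddings(model_name="BAAI/bge-m3")  # assumed model
repo_path = snapshot_download(repo_id="StevenChen16/legal-faiss",  # assumed repo
                              repo_type="dataset")

vector_store = FAISS.load_local(repo_path, embedding_model,
                                allow_dangerous_deserialization=True)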
 
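The second and third hunks fix the same bug from both ends: query_vector_store used to return a list, and chat_llama3_8b concatenated that list straight into a string ("Based on this citations: " + citation), which raises a TypeError whenever retrieval succeeds. A self-contained sketch of the before/after behaviour, with stand-in documents:

# Stand-in retrieval results; in app.py these come from retriever.invoke(query).
docs = ["Sale of Goods Act s.12 ...", "Consumer Rights Act s.9 ..."]

citation_old = docs                            # old: a list, so "str" + list -> TypeError
citation_new = " ".join(docs) if docs else ""  # new: one string, falsy when retrieval is empty

if citation_new:
    prompt = f"Based on these citations: {citation_new}\nPlease answer question: ..."
else:
    prompt = "..."
print(prompt)

Switching the guard from "citation != None" to "if citation:" also matters here: the old list return was never None (an empty retrieval gave []), so the None check always passed, while the new empty-result value "" is correctly treated as no citation.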
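The generation hunk consumes a TextIteratorStreamer, but the model.generate call itself sits outside the changed lines. For reference, the standard pattern this loop relies on runs generate() in a background thread so partial text can be yielded as tokens arrive; stream_reply below is a hypothetical wrapper, not code from app.py:

# Hypothetical helper showing the usual TextIteratorStreamer pattern the
# hunk above relies on; model, tokenizer, and terminators come from app.py.
from threading import Thread
from transformers import TextIteratorStreamer

def stream_reply(model, tokenizer, input_ids, terminators,
                 temperature, max_new_tokens):
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0,
                                    skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,   # greedy when temperature == 0, as in the diff
        temperature=temperature,
        eos_token_id=terminators,
    )
    # generate() blocks until finished, so it runs in a background thread
    # while the loop below streams partial output to the caller.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)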