cawacci committed
Commit
d2791db
1 Parent(s): 2178651

Upload app.py

Files changed (1)
  1. app.py +138 -70
app.py CHANGED
@@ -11,6 +11,7 @@ import os
 import time
 import gc  # free up memory
 import re  # clean up text with regular expressions
+import regex  # used to extract kanji

 # HuggingFace
 import torch
@@ -115,6 +116,55 @@ class SessionState:

         self.cache_clear()

+# --------------------------------------
+# Custom ConversationChain that does not use memory
+# --------------------------------------
+from typing import Dict, List
+
+from langchain.chains.conversation.prompt import PROMPT
+from langchain.chains.llm import LLMChain
+from langchain.pydantic_v1 import Extra, Field, root_validator
+from langchain.schema import BasePromptTemplate
+
+class ConversationChain(LLMChain):
+    """Chain to have a conversation without loading context from memory.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.llms import OpenAI
+
+            conversation = ConversationChain(llm=OpenAI())
+    """
+
+    prompt: BasePromptTemplate = PROMPT
+    """Default conversation prompt to use."""
+
+    input_key: str = "input"  #: :meta private:
+    output_key: str = "response"  #: :meta private:
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+
+    @property
+    def input_keys(self) -> List[str]:
+        """Use this because some prompt variables come from history."""
+        return [self.input_key]
+
+    @root_validator()
+    def validate_prompt_input_variables(cls, values: Dict) -> Dict:
+        """Validate that prompt input variables are consistent without memory."""
+        input_key = values["input_key"]
+        prompt_variables = values["prompt"].input_variables
+        if input_key not in prompt_variables:
+            raise ValueError(
+                f"The prompt expects {prompt_variables}, but {input_key} is not found."
+            )
+        return values
+
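The class above deliberately shadows LangChain's own ConversationChain and skips the memory hooks, so any history has to be rendered into the prompt by the caller. A minimal usage sketch under that assumption (FakeListLLM is LangChain's test stub; the prompt text and keys below are illustrative, not part of this commit):

from langchain.llms.fake import FakeListLLM
from langchain.prompts import PromptTemplate

llm = FakeListLLM(responses=["はい、かしこまりました。"])  # stand-in for the real model
prompt = PromptTemplate(input_variables=["query"], template="Question: {query}\n日本語の回答: ")
chain = ConversationChain(llm=llm, prompt=prompt, input_key="query", output_key="output_text")
print(chain.predict(query="営業時間を教えてください"))  # no chat_history is injected by the chain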
 # --------------------------------------
 # Custom TextSplitter (splits text to fit within the LLM's token limit)
 # (Reference) https://www.sato-susumu.com/entry/2023/04/30/131338
@@ -157,10 +207,21 @@ def name_detector(text: str) -> list:

         node = node.next

-    names = list(set(names))
+    # Deduplicate, then keep only values that contain kanji
+    names = filter_kanji(list(set(names)))

     return names

+# --------------------------------------
+# Keep only the values in a list that contain kanji
+# --------------------------------------
+def filter_kanji(lst) -> list:
+    def contains_kanji(s):
+        p = regex.compile(r'\p{Script=Han}+')
+        return bool(p.search(s))
+
+    return [item for item in lst if contains_kanji(item)]
+
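A quick illustration of the filter (the sample names are hypothetical): \p{Script=Han} in the third-party regex module matches kanji but not hiragana, katakana, or Latin letters, so romanized and kana-only candidates are dropped:

print(filter_kanji(["山田太郎", "John Smith", "さくら", "佐藤"]))
# -> ['山田太郎', '佐藤']  (only items containing at least one kanji remain)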
 # --------------------------------------
 # Translate memory with DeepL to reduce the token count (when using OpenAI models)
 # --------------------------------------
@@ -207,21 +268,12 @@ def deepl_memory(ss: SessionState) -> (SessionState):
 def web_search(ss: SessionState, query) -> (SessionState, str):

     search = DuckDuckGoSearchRun(verbose=True)
+    names = []
+    names.extend(name_detector(query))

     for i in range(3):
         web_result = search(query)

-        # Extract person names
-        names = []
-        names.extend(name_detector(query))
-        names.extend(name_detector(web_result))
-        if len(names)==0:
-            names = ""
-        elif len(names)==1:
-            names = names[0]
-        else:
-            names = ", ".join(names)
-
         if ss.current_model == "gpt-3.5-turbo":
             text = [query, web_result]
             params = {
@@ -235,21 +287,33 @@ def web_search(ss: SessionState, query) -> (SessionState, str):
             request = requests.post(DEEPL_API_ENDPOINT, data=params)
             response = request.json()

-            query = response["translations"][0]["text"]
-            web_result = response["translations"][1]["text"]
-            web_result = ss.web_summary_chain({'query': query, 'context': web_result})['text']
-            if web_result != "NO INFO":
+            query_eng = response["translations"][0]["text"]
+            web_result_eng = response["translations"][1]["text"]
+            web_result_eng = ss.web_summary_chain({'query': query_eng, 'context': web_result_eng})['text']
+            if "$$NO INFO$$" in web_result_eng:
+                web_result_eng = ss.web_summary_chain({'query': query_eng, 'context': web_result_eng})['text']
+            if "$$NO INFO$$" not in web_result_eng:
                 break

+    # Extract person names from the search result and join them into a string
+    names.extend(name_detector(web_result))
+    if len(names)==0:
+        names = ""
+    elif len(names)==1:
+        names = names[0]
+    else:
+        names = ", ".join(names)
+
+    # Pass on a query that includes the web search result.
     if names != "":
         web_query = f"""
-        {query}
-        Use the following Suggested Answer Source as a reliable reference to answer the question above in Japanese. When translating names of people, refer to Names as a translation guide.
-        Suggested Answer Source: {web_result}
+        {query_eng}
+        Use the following Suggested Answer as a reference to answer the question above in Japanese. When translating names of people, refer to Names as a translation guide.
+        Suggested Answer: {web_result_eng}
         Names: {names}
         """.strip()
     else:
-        web_query = query + "\nUse the following Suggested Answer Source as a reliable reference to answer the question above in the Japanese.\n===\nSuggested Answer Source: " + web_result + "\n"
+        web_query = query_eng + "\nUse the following Suggested Answer as a reference to answer the question above in Japanese.\n===\nSuggested Answer: " + web_result_eng + "\n"


     return ss, web_query
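The params dict sent to DEEPL_API_ENDPOINT is cut off by the hunk above. For reference, a request against DeepL's v2 /translate API has roughly this shape (how the key is stored is an assumption, not shown in this commit):

params = {
    "auth_key": os.getenv("DEEPL_API_KEY"),  # assumed to come from an environment variable
    "text": text,                            # here [query, web_result], two strings in one request
    "target_lang": "EN",
}
# DeepL returns one translation entry per input string, which is why the code reads
# response["translations"][0] for the query and [1] for the search result.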
@@ -265,29 +329,19 @@ def web_search(ss: SessionState, query) -> (SessionState, str):
 # --------------------------------------
 # Conversation Chain Template
 # --------------------------------------
-
 # Tokens: OpenAI 104/ Llama 105 <- In Japanese: Tokens: OpenAI 191/ Llama 162
-# sys_chat_message = """
-# You are an outstanding AI concierge. Understand the intent of the customer's questions based on
-# the conversation history. Then, answer them with many specific and detailed information in Japanese.
-# If you do not know the answer to a question, do make up an answer and says
-# "誠に申し訳ございませんが、その点についてはわかりかねます".
-# """.replace("\n", "")

 sys_chat_message = """
-You are an outstanding AI concierge.
-1) Understand the intent of the customer's questions based on the conversation history.
-2) Then, by using references if available, answer the question with many specific and detailed information in Japanese.
-3) If the reference does not provide answer to the question at all, and you do not know the answer, do make up an answer and says "誠に申し訳ございませんが、その点についてはわかりかねます".
-""".strip()
+You are an AI concierge who carefully answers questions from customers based on references.
+You understand what the customer wants to know, and give many specific details in Japanese
+using sentences extracted from the following references when available. If you do not know
+the answer, do not make up an answer and reply, "誠に申し訳ございませんが、その点についてはわかりかねます".
+""".replace("\n", "")

 chat_common_format = """
 ===
 Question: {query}
-===
-Conversation History:
-{chat_history}
-===
+
 日本語の回答: """

 chat_template_std = f"{sys_chat_message}{chat_common_format}"
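Since sys_chat_message is collapsed to one line by .replace("\n", "") and joined to the format with an f-string, the Llama 2 variant (visible in the next hunk header) simply wraps the same two pieces in instruction tags. Illustrative shape with shortened placeholders:

demo_sys = "You are an AI concierge ..."
demo_fmt = "\n===\nQuestion: {query}\n\n日本語の回答: "
print(f"<s>[INST] <<SYS>>{demo_sys}<</SYS>>{demo_fmt}[/INST]")
# <s>[INST] <<SYS>>You are an AI concierge ...<</SYS>>
# ===
# Question: {query}
#
# 日本語の回答: [/INST]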
@@ -297,35 +351,46 @@ chat_template_llama2 = f"<s>[INST] <<SYS>>{sys_chat_message}<</SYS>>{chat_common
 # QA Chain Template (Stuff)
 # --------------------------------------
 # Tokens: OpenAI 113/ Llama 111 <- In Japanese: Tokens: OpenAI 256/ Llama 225
-sys_qa_message = """
-You are an AI concierge who carefully answers questions from customers based on references.
-Understand the intent of the customer's questions based on the conversation history. Then, give
-a specific answer in Japanese using sentences extracted from the following references. If you do
-not know the answer, do not make up an answer and reply, "誠に申し訳ございませんが、その点についてはわかりかねます".
-""".replace("\n", "")
+# sys_qa_message = """
+# You are an AI concierge who carefully answers questions from customers based on references.
+# You understand what the customer wants to know from the Conversation History and Question,
+# and give a specific answer in Japanese using sentences extracted from the following references.
+# If you do not know the answer, do not make up an answer and reply,
+# "誠に申し訳ございませんが、その点についてはわかりかねます".
+# """.replace("\n", "")
+
+# qa_common_format = """
+# ===
+# Question: {query}
+# References: {context}
+# ===
+# Conversation History:
+# {chat_history}
+# ===
+# 日本語の回答: """

 qa_common_format = """
 ===
 Question: {query}
 References: {context}
-===
-Conversation History:
-{chat_history}
-===
+
 日本語の回答: """

-qa_template_std = f"{sys_qa_message}{qa_common_format}"
-qa_template_llama2 = f"<s>[INST] <<SYS>>{sys_qa_message}<</SYS>>{qa_common_format}[/INST]"
+qa_template_std = f"{sys_chat_message}{qa_common_format}"
+qa_template_llama2 = f"<s>[INST] <<SYS>>{sys_chat_message}<</SYS>>{qa_common_format}[/INST]"

+# qa_template_std = f"{sys_qa_message}{qa_common_format}"
+# qa_template_llama2 = f"<s>[INST] <<SYS>>{sys_qa_message}<</SYS>>{qa_common_format}[/INST]"

 # --------------------------------------
 # QA Chain Template (Map Reduce)
 # --------------------------------------
 # 1. Prompt for the chain that generates a question from the conversation history and the latest question
 query_generator_message = """
-Referring to the "Conversation History", reformat the user's "Additional Question"
-to a specific question by filling in the missing subject, verb, objects, complements,
-and other necessary information to get a better search result. Answer in Japanese.
+Referring to the "Conversation History", especially to the most recent conversation,
+reformat the user's "Additional Question" into a specific question in Japanese by
+filling in the missing subject, verb, objects, complements, and other necessary
+information to get a better search result. Answer in 日本語(Japanese).
 """.replace("\n", "")

@@ -334,30 +399,25 @@ query_generator_common_format = """
 {chat_history}

 [Additional Question] {query}
-明確な日本語の質問文: """
+明確な質問文: """

 query_generator_template_std = f"{query_generator_message}{query_generator_common_format}"
 query_generator_template_llama2 = f"<s>[INST] <<SYS>>{query_generator_message}<</SYS>>{query_generator_common_format}[/INST]"


 # 2. Prompt for the chain that summarizes the references using the generated question
-# question_prompt_message = """
-# From the following references, extract key information relevant to the question
-# and summarize it in a natural English sentence with clear subject, verb, object,
-# and complement. If there is no information in the reference that answers the question,
-# do not summarize and simply answer "NO INFO"
-# """.replace("\n", "")

 question_prompt_message = """
-1. Determine if any of the following references provide information that answers the Question, and if there is no information, answer "NO INFO" and stop.
-2. From the following references, extract key information relevant to the question and summarize it in a natural English sentence with clear subject, verb, object, and complement.
-""".strip()
+From the following references, extract key information relevant to the question
+and summarize it in a natural English sentence with clear subject, verb, object,
+and complement.
+""".replace("\n", "")

 question_prompt_common_format = """
 ===
 [Question] {query}
-[references] {context}
-[Answer]"""
+[References] {context}
+[Key Information] """

 question_prompt_template_std = f"{question_prompt_message}{question_prompt_common_format}"
 question_prompt_template_llama2 = f"<s>[INST] <<SYS>>{question_prompt_message}<</SYS>>{question_prompt_common_format}[/INST]"
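To see what the query generator receives at runtime, the template can be formatted directly (the conversation below is hypothetical):

from langchain.prompts import PromptTemplate

p = PromptTemplate(template=query_generator_template_std,
                   input_variables=["chat_history", "query"])
print(p.format(chat_history="User: 富士山の高さは？\nAI: 3,776mです。",
               query="登るのにどれくらいかかりますか？"))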
@@ -578,11 +638,12 @@ def set_chains(ss: SessionState, summarization_mode) -> SessionState:
         ss.query_generator = LLMChain(llm=ss.llm, prompt=query_generator_prompt, verbose=True)

     if ss.conversation_chain is None:
-        chat_prompt = PromptTemplate(input_variables=['query', 'chat_history'], template=chat_template)
+        # chat_prompt = PromptTemplate(input_variables=['query', 'chat_history'], template=chat_template)
+        chat_prompt = PromptTemplate(input_variables=['query'], template=chat_template)
         ss.conversation_chain = ConversationChain(
             llm = ss.llm,
             prompt = chat_prompt,
-            memory = ss.memory,
+            # memory = ss.memory,
             input_key = "query",
             output_key = "output_text",
             verbose = True,
@@ -590,13 +651,16 @@ def set_chains(ss: SessionState, summarization_mode) -> SessionState:

     if ss.qa_chain is None:
         if summarization_mode == "stuff":
-            qa_prompt = PromptTemplate(input_variables=['context', 'query', 'chat_history'], template=qa_template)
-            ss.qa_chain = load_qa_chain(ss.llm, chain_type="stuff", memory=ss.memory, prompt=qa_prompt)
+            # qa_prompt = PromptTemplate(input_variables=['context', 'query', 'chat_history'], template=qa_template)
+            qa_prompt = PromptTemplate(input_variables=['context', 'query'], template=qa_template)
+            # ss.qa_chain = load_qa_chain(ss.llm, chain_type="stuff", memory=ss.memory, prompt=qa_prompt)
+            ss.qa_chain = load_qa_chain(ss.llm, chain_type="stuff", prompt=qa_prompt, verbose=True)

         elif summarization_mode == "map_reduce":
             question_prompt = PromptTemplate(template=question_template, input_variables=["context", "query"])
             combine_prompt = PromptTemplate(template=combine_template, input_variables=["summaries", "query"])
-            ss.qa_chain = load_qa_chain(ss.llm, chain_type="map_reduce", return_map_steps=True, memory=ss.memory, question_prompt=question_prompt, combine_prompt=combine_prompt)
+            # ss.qa_chain = load_qa_chain(ss.llm, chain_type="map_reduce", return_map_steps=True, memory=ss.memory, question_prompt=question_prompt, combine_prompt=combine_prompt)
+            ss.qa_chain = load_qa_chain(ss.llm, chain_type="map_reduce", return_map_steps=True, question_prompt=question_prompt, combine_prompt=combine_prompt, verbose=True)

     if ss.web_summary_chain is None:
         question_prompt = PromptTemplate(template=question_template, input_variables=["context", "query"])
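With memory removed from load_qa_chain, the stuff chain needs only documents and the query at call time; qa_predict's result["output_text"] later in the file matches this output key. A hedged sketch of a call site (document content hypothetical):

from langchain.schema import Document

docs = [Document(page_content="営業時間は9時から18時です。")]
result = ss.qa_chain({"input_documents": docs, "query": "営業時間は？"})
print(result["output_text"])  # {context} in qa_template is filled from input_documents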
@@ -853,6 +917,8 @@ def bot(ss: SessionState, query, qa_flag, web_flag, summarization_mode) -> (Sess

 def chat_predict(ss: SessionState, query) -> SessionState:
     response = ss.conversation_chain.predict(query=query)
+    ss.memory.chat_memory.add_user_message(query)
+    ss.memory.chat_memory.add_ai_message(response)
     ss.dialogue[-1] = (ss.dialogue[-1][0], response)
     return ss
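The two added lines replicate what the memory hook of the stock ConversationChain used to do automatically. Assuming ss.memory is a ConversationBufferMemory (its construction is outside this diff), they append one human and one AI turn:

from langchain.memory import ConversationBufferMemory

m = ConversationBufferMemory()
m.chat_memory.add_user_message("こんにちは")
m.chat_memory.add_ai_message("こんにちは。ご用件をどうぞ。")
print(m.buffer)  # Human: こんにちは\nAI: こんにちは。ご用件をどうぞ。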
 
@@ -890,10 +956,12 @@ def qa_predict(ss: SessionState, query) -> SessionState:
         if result["output_text"] != "":
             response = result["output_text"] + sources
             ss.dialogue[-1] = (ss.dialogue[-1][0], response)
+            ss.memory.chat_memory.add_user_message(original_query)
+            ss.memory.chat_memory.add_ai_message(response)
             return ss
-        else:
+        # else:
             # If the output is blank, delete the most recent history and retry
-            ss.memory.chat_memory.messages = ss.memory.chat_memory.messages[:-2]
+            # ss.memory.chat_memory.messages = ss.memory.chat_memory.messages[:-2]

     # If the output is still blank after three attempts
     response = "3回試行しましたが、情報が生成できませんでした。"
 