cawacci committed
Commit 6f80afe
1 Parent(s): 6a3ce6b

Upload 2 files

Files changed (2)
  1. app.py +96 -41
  2. requirements.txt +3 -1
app.py CHANGED
@@ -22,6 +22,7 @@ from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.chat_models import ChatOpenAI
 
 # LangChain
+import langchain
 from langchain.llms import HuggingFacePipeline
 from transformers import pipeline
 
@@ -45,8 +46,8 @@ import gradio as gr
 from pypdf import PdfReader
 import requests  # DeepL API request
 
-# test
-import langchain  # (to enable debug=True)
+# MeCab
+import MeCab
 
 # --------------------------------------
 # Class that records per-user session variable values
@@ -69,6 +70,7 @@ class SessionState:
         self.conversation_chain = None  # ConversationChain
         self.query_generator = None     # Query Refiner with Chat history
         self.qa_chain = None            # load_qa_chain
+        self.web_summary_chain = None   # Summarize web search result
         self.embedded_urls = []
         self.similarity_search_k = None # No. of similarity search documents to find.
         self.summarization_mode = None  # Stuff / Map Reduce / Refine
@@ -132,6 +134,33 @@ text_splitter = JPTextSplitter(
     chunk_overlap = chunk_overlap,  # maximum number of overlapping characters
 )
 
+# --------------------------------------
+# Extract person names from text
+# --------------------------------------
+def name_detector(text: str) -> list:
+    mecab = MeCab.Tagger()
+    mecab.parse('')  # workaround for a known MeCab binding bug
+    node = mecab.parseToNode(text).next
+    names = []
+
+    while node:
+        if node.feature.split(',')[3] == "姓":
+            if node.next and node.next.feature.split(',')[3] == "名":
+                names.append(str(node.surface) + str(node.next.surface))
+            else:
+                names.append(node.surface)
+        if node.feature.split(',')[3] == "名":
+            if node.prev and node.prev.feature.split(',')[3] == "姓":
+                pass
+            else:
+                names.append(str(node.surface))
+
+        node = node.next
+
+    names = list(set(names))
+
+    return names
+
 # --------------------------------------
 # Translate memory with DeepL to reduce token count (when using OpenAI models)
 # --------------------------------------
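Note on `name_detector`: the `feature.split(',')[3]` tests depend on the dictionary's feature layout. Both ipadic and the unidic-lite dictionary added in requirements.txt put 姓 (surname) and 名 (given name) at index 3 for person-name nouns, so the check holds for either. A minimal sketch to inspect those fields, assuming mecab-python3 and unidic-lite are installed:

```python
# Print the feature fields that name_detector keys on; field 3 holds
# "姓"/"名" for person-name nouns.
import MeCab

mecab = MeCab.Tagger()
mecab.parse('')  # same surface-corruption workaround as in the patch
node = mecab.parseToNode("山田太郎さんが来た").next
while node:
    if node.surface:  # the trailing EOS node has an empty surface
        print(node.surface, node.feature.split(',')[:4])
    node = node.next
# Expected output includes 山田 -> [..., '人名', '姓'] and 太郎 -> [..., '名'],
# so name_detector would return ['山田太郎'].
```

Since the function deduplicates with `list(set(names))`, the order of the returned names is not guaranteed.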
@@ -175,11 +204,22 @@ def deepl_memory(ss: SessionState) -> (SessionState):
 # DEEPL_API_ENDPOINT = "https://api-free.deepl.com/v2/translate"
 # DEEPL_API_KEY = os.getenv("DEEPL_API_KEY")
 
-def web_search(query, current_model) -> str:
-    search = DuckDuckGoSearchRun()
+def web_search(ss: SessionState, query) -> (SessionState, str):
+    search = DuckDuckGoSearchRun(verbose=True)
     web_result = search(query)
 
-    if current_model == "gpt-3.5-turbo":
+    # Extract person names
+    names = []
+    names.extend(name_detector(query))
+    names.extend(name_detector(web_result))
+    if len(names) == 0:
+        names = ""
+    elif len(names) == 1:
+        names = names[0]
+    else:
+        names = ", ".join(names)
+
+    if ss.current_model == "gpt-3.5-turbo":
         text = [query, web_result]
         params = {
             "auth_key": DEEPL_API_KEY,
@@ -193,19 +233,28 @@ def web_search(query, current_model) -> str:
         response = request.json()
 
         query = response["translations"][0]["text"]
-        web_result = response["translations"][1]["text"]
+        web_result = response["translations"][1]["text"]
+        web_result = ss.web_summary_chain({'query': query, 'context': web_result})['text']
 
-    web_query = query + "\nUse the following information as a reference to answer the question above in the Japanese.\n===\nReference: " + web_result + "\n==="
+    if names != "":
+        web_query = f"""
+{query}
+Use the following information as a reference to answer the question above in Japanese. When translating names of Japanese people, refer to Japanese Names as a translation guide.
+Reference: {web_result}
+Japanese Names: {names}
+""".strip()
+    else:
+        web_query = query + "\nUse the following information as a reference to answer the question above in Japanese.\n===\nReference: " + web_result + "\n==="
 
-    return web_query
+    return ss, web_query
 
 # --------------------------------------
 # Various LangChain custom prompts
 # llama tokenizer
-# https://belladoreai.github.io/llama-tokenizer-js/example-demo/build/
-
+# https://belladoreai.github.io/llama-tokenizer-js/example-demo/build/
 # OpenAI tokenizer
-# https://platform.openai.com/tokenizer
+# https://platform.openai.com/tokenizer
 # --------------------------------------
 
 # --------------------------------------
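For context, the DeepL call pattern `web_search` relies on: a single request may carry several `text` values, and the response lists the translations in the same order, which is why index 0 is the translated query and index 1 the translated search result. A minimal sketch using the endpoint and key names from the commented-out constants above (free-tier endpoint assumed):

```python
import os
import requests

DEEPL_API_ENDPOINT = "https://api-free.deepl.com/v2/translate"
DEEPL_API_KEY = os.getenv("DEEPL_API_KEY")

params = {
    "auth_key": DEEPL_API_KEY,
    "text": ["渋谷のおすすめの店は?", "検索結果のテキスト"],  # query, web_result
    "target_lang": "EN",
}
response = requests.post(DEEPL_API_ENDPOINT, data=params).json()
en_query = response["translations"][0]["text"]
en_result = response["translations"][1]["text"]
```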
@@ -214,19 +263,18 @@ def web_search(query, current_model) -> str:
 
 # Tokens: OpenAI 104/ Llama 105 <- In Japanese: Tokens: OpenAI 191/ Llama 162
 sys_chat_message = """
-The following is a conversation between an AI concierge and a customer.
-The AI understands what the customer wants to know from the conversation history and the latest question,
-and gives many specific details in Japanese. If the AI does not know the answer to a question, it does not
-make up an answer and says "誠に申し訳ございませんが、その点についてはわかりかねます".
+You are an outstanding AI concierge. You understand your customers' needs from their questions and answer
+ them with specific and detailed information in Japanese. If you do not know the answer to a question, do
+ not make up an answer; say "誠に申し訳ございませんが、その点についてはわかりかねます". Ignore Conversation History.
 """.replace("\n", "")
 
 chat_common_format = """
 ===
 Question: {query}
-
-Conversation History:
+===
+Conversation History(Ignore):
 {chat_history}
-
+===
 日本語の回答: """
 
 chat_template_std = f"{sys_chat_message}{chat_common_format}"
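One subtlety in these templates: `.replace("\n", "")` flattens the triple-quoted system message into a single line before it is concatenated with the `*_common_format` block, and the removed newlines contribute no spaces. Each wrapped line in the source therefore needs a leading or trailing space to keep words apart:

```python
# The space before "If" is deliberate; without it the flattened string
# would read "concierge.If".
msg = """
You are an outstanding AI concierge.
 If you do not know the answer, say so.
""".replace("\n", "")
print(repr(msg))
# -> 'You are an outstanding AI concierge. If you do not know the answer, say so.'
```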
@@ -238,21 +286,23 @@ chat_template_llama2 = f"<s>[INST] <<SYS>>{sys_chat_message}<</SYS>>{chat_common_format}[/INST]"
 
 # Tokens: OpenAI 113/ Llama 111 <- In Japanese: Tokens: OpenAI 256/ Llama 225
 sys_qa_message = """
 You are an AI concierge who carefully answers questions from customers based on references.
-You understand what the customer wants to know from the Conversation History and Question,
-and give a specific answer in Japanese using sentences extracted from the following references.
-If you do not know the answer, do not make up an answer and reply,
-"誠に申し訳ございませんが、その点についてはわかりかねます".
+ You understand what the customer wants to know from the Question, and give a specific answer in
+ Japanese using sentences extracted from the following references. If you do not know the answer,
+ do not make up an answer and reply, "誠に申し訳ございませんが、その点についてはわかりかねます".
+ Ignore Conversation History.
 """.replace("\n", "")
 
 qa_common_format = """
 ===
 Question: {query}
 References: {context}
-Conversation History:
+===
+Conversation History(Ignore):
 {chat_history}
-
+===
 日本語の回答: """
 
+
 qa_template_std = f"{sys_qa_message}{qa_common_format}"
 qa_template_llama2 = f"<s>[INST] <<SYS>>{sys_qa_message}<</SYS>>{qa_common_format}[/INST]"
@@ -262,8 +312,8 @@ qa_template_llama2 = f"<s>[INST] <<SYS>>{sys_qa_message}<</SYS>>{qa_common_format}[/INST]"
 
 # 1. Prompt for the chain that generates a question from the conversation history and the latest question
 query_generator_message = """
 Referring to the "Conversation History", reformat the user's "Additional Question"
-to a specific question in Japanese by filling in the missing subject, verb, objects,
-complements, and other necessary information to get a better search result.
+ to a specific question by filling in the missing subject, verb, objects, complements,
+ and other necessary information to get a better search result. Answer in Japanese.
 """.replace("\n", "")
 
 query_generator_common_format = """
@@ -272,7 +322,7 @@ query_generator_common_format = """
 {chat_history}
 
 [Additional Question] {query}
-明確な質問文: """
+明確な日本語の質問文: """
 
 query_generator_template_std = f"{query_generator_message}{query_generator_common_format}"
 query_generator_template_llama2 = f"<s>[INST] <<SYS>>{query_generator_message}<</SYS>>{query_generator_common_format}[/INST]"
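How the refiner built from these templates is invoked is not part of this diff; presumably the chain constructed in `set_chains()` below is called with both variables filled, along these lines (hypothetical call site):

```python
# Hypothetical usage: rewrite a terse follow-up question into a
# self-contained one before running similarity search.
refined_query = ss.query_generator.predict(
    chat_history="Human: 渋谷のおすすめのカフェは?\nAI: ○○がおすすめです。",
    query="営業時間は?",
)
```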
@@ -287,8 +337,8 @@ and complement.
 
 question_prompt_common_format = """
 ===
-[references] {context}
 [Question] {query}
+[references] {context}
 [Summary] """
 
 question_prompt_template_std = f"{question_prompt_message}{question_prompt_common_format}"
@@ -305,17 +355,14 @@ If you do not know the answer, do not make up an answer and reply,
 
 combine_prompt_common_format = """
 ===
-Question:
-{query}
-===
+Question: {query}
 Reference: {summaries}
-===
 日本語の回答: """
 
+
 combine_prompt_template_std = f"{combine_prompt_message}{combine_prompt_common_format}"
 combine_prompt_template_llama2 = f"<s>[INST] <<SYS>>{combine_prompt_message}<</SYS>>{combine_prompt_common_format}[/INST]"
 
-
 # --------------------------------------
 # Summary prompt for ConversationSummaryBufferMemory
 # Source → https://github.com/langchain-ai/langchain/blob/894c272a562471aadc1eb48e4a2992923533dea0/langchain/memory/prompt.py#L26-L49
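These `*_template_std` strings are still plain text at this point; the f-strings only splice the message and format parts together, leaving `{query}`, `{summaries}`, and friends for LangChain to fill. `set_chains()` below wraps them in `PromptTemplate`, after which formatting works as usual:

```python
from langchain.prompts import PromptTemplate

combine_prompt = PromptTemplate(
    template=combine_prompt_template_std,
    input_variables=["summaries", "query"],
)
print(combine_prompt.format(query="営業時間は?", summaries="10時から18時まで営業。"))
```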
@@ -508,6 +555,10 @@ def set_chains(ss: SessionState, summarization_mode) -> SessionState:
     # --------------------------------------
     # Configure the Conversation/QA chains
     # --------------------------------------
+    if ss.query_generator is None:
+        query_generator_prompt = PromptTemplate(template=query_generator_template, input_variables=["chat_history", "query"])
+        ss.query_generator = LLMChain(llm=ss.llm, prompt=query_generator_prompt, verbose=True)
+
     if ss.conversation_chain is None:
         chat_prompt = PromptTemplate(input_variables=['query', 'chat_history'], template=chat_template)
         ss.conversation_chain = ConversationChain(
@@ -525,13 +576,14 @@ def set_chains(ss: SessionState, summarization_mode) -> SessionState:
         ss.qa_chain = load_qa_chain(ss.llm, chain_type="stuff", memory=ss.memory, prompt=qa_prompt)
 
     elif summarization_mode == "map_reduce":
-        query_generator_prompt = PromptTemplate(template=query_generator_template, input_variables = ["chat_history", "query"])
-        ss.query_generator = LLMChain(llm=ss.llm, prompt=query_generator_prompt)
-
         question_prompt = PromptTemplate(template=question_template, input_variables=["context", "query"])
         combine_prompt = PromptTemplate(template=combine_template, input_variables=["summaries", "query"])
         ss.qa_chain = load_qa_chain(ss.llm, chain_type="map_reduce", return_map_steps=True, memory=ss.memory, question_prompt=question_prompt, combine_prompt=combine_prompt)
 
+    if ss.web_summary_chain is None:
+        question_prompt = PromptTemplate(template=question_template, input_variables=["context", "query"])
+        ss.web_summary_chain = LLMChain(llm=ss.llm, prompt=question_prompt, verbose=True)
+
     return ss
 
 def initialize_db(ss: SessionState) -> SessionState:
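Why `web_search` indexes the chain result with `['text']`: calling an `LLMChain` with a dict of inputs returns the inputs plus one output key, which defaults to `"text"`. A self-contained sketch with a fake LLM standing in for `ss.llm`:

```python
from langchain import LLMChain, PromptTemplate
from langchain.llms.fake import FakeListLLM

prompt = PromptTemplate(
    template="[Question] {query}\n[references] {context}\n[Summary] ",
    input_variables=["context", "query"],
)
llm = FakeListLLM(responses=["要約されたテキスト"])  # stand-in for ss.llm
chain = LLMChain(llm=llm, prompt=prompt)
result = chain({"query": "質問", "context": "検索結果"})
print(result["text"])  # -> 要約されたテキスト
```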
@@ -761,16 +813,16 @@ def bot(ss: SessionState, query, qa_flag, web_flag, summarization_mode) -> (Sess
     # QA Model
     if qa_flag is True and ss.embeddings is not None and ss.db is not None:
         if web_flag:
-            web_query = web_search(query, ss.current_model)
+            ss, web_query = web_search(ss, query)
             ss = qa_predict(ss, web_query)
             ss.memory.chat_memory.messages[-2].content = query
         else:
-            ss = qa_predict(ss, query)  # generate an answer with the LLM
+            ss = qa_predict(ss, query)
 
     # Chat Model
     else:
         if web_flag:
-            web_query = web_search(query, ss.current_model)
+            ss, web_query = web_search(ss, query)
             ss = chat_predict(ss, web_query)
             ss.memory.chat_memory.messages[-2].content = query
         else:
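The `messages[-2].content = query` lines keep the bulky web context out of memory: `qa_predict`/`chat_predict` append a user turn (the long `web_query`) followed by an AI turn, so the second-to-last message is the user turn, and overwriting its content stores only what the user actually typed. A sketch of the mechanism:

```python
from langchain.memory import ConversationBufferMemory

memory = ConversationBufferMemory()
memory.chat_memory.add_user_message("long web_query with references ...")
memory.chat_memory.add_ai_message("answer")
memory.chat_memory.messages[-2].content = "original user question"
print([m.content for m in memory.chat_memory.messages])
# -> ['original user question', 'answer']
```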
@@ -788,6 +840,8 @@ def chat_predict(ss: SessionState, query) -> SessionState:
 
 def qa_predict(ss: SessionState, query) -> SessionState:
 
+    original_query = query
+
     # Settings for the Rinna model (fix the newline characters in the query)
     if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft":
         query = query.strip().replace("\n", "<NL>")
@@ -829,7 +883,7 @@ def qa_predict(ss: SessionState, query) -> SessionState:
         response += "参考文献の抽出には成功していますので、言語モデルを変えてお試しください。"
 
     # Add the user message and the AI message
-    ss.memory.chat_memory.add_user_message(query.replace("<NL>", "\n"))
+    ss.memory.chat_memory.add_user_message(original_query.replace("<NL>", "\n"))
     ss.memory.chat_memory.add_ai_message(response)
     ss.dialogue[-1] = (ss.dialogue[-1][0], response)  # conversation history
     return ss
@@ -1028,4 +1082,5 @@ with gr.Blocks() as demo:
 
 if __name__ == "__main__":
     demo.queue(concurrency_count=5)
-    demo.launch(debug=True)
+    demo.launch(debug=True,)
+
requirements.txt CHANGED
@@ -21,4 +21,6 @@ numpy==1.23.5
 pandas==1.5.3
 chromedriver-autoinstaller
 chromedriver-binary
-duckduckgo-search==3.8.5
+duckduckgo-search==3.8.5
+mecab-python3==1.0.6
+unidic-lite==1.0.8
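The two new pins belong together: mecab-python3 ships without a dictionary, and unidic-lite provides one that `MeCab.Tagger()` picks up automatically, so no extra setup step is needed in the Space. A quick sanity check:

```python
import MeCab

tagger = MeCab.Tagger()
print(tagger.parse("すもももももももものうち"))
```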