cawacci committed on
Commit
65753fd
1 Parent(s): a89078f

Upload 5 files

Files changed (5)
  1. app.py +1022 -0
  2. bear.png +0 -0
  3. packages.txt +1 -0
  4. penguin.png +0 -0
  5. requirements.txt +23 -0
app.py ADDED
@@ -0,0 +1,1022 @@
# Chat with Documents 2 by cawacci
# 2023.9.10 Built as the capstone app for the Kikagaku long-term course (April 2023 cohort)

# --------------------------------------
# Libraries
# --------------------------------------
import os
import time
import gc    # memory release
import re    # clean up text with regular expressions

# HuggingFace
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# OpenAI
import openai
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI

# LangChain
from langchain.llms import HuggingFacePipeline
from transformers import pipeline

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import LLMChain, VectorDBQA
from langchain.vectorstores import Chroma

from langchain import PromptTemplate, ConversationChain
from langchain.chains.question_answering import load_qa_chain    # QA chat
from langchain.document_loaders import SeleniumURLLoader         # URL loading
from langchain.docstore.document import Document                 # wrap text as Documents
from langchain.memory import ConversationSummaryBufferMemory     # chat history

from typing import Any
from langchain.text_splitter import RecursiveCharacterTextSplitter

from langchain.tools import DuckDuckGoSearchRun

# Gradio
import gradio as gr
from pypdf import PdfReader
import requests    # DeepL API requests

# test
import langchain    # (so that langchain.debug = True can be set)
# --------------------------------------
# Class that stores per-user session state
#   (Reference) https://blog.shikoan.com/gradio-state/
# --------------------------------------
class SessionState:
    def __init__(self):
        # Hugging Face
        self.tokenizer = None
        self.pipe = None
        self.model = None

        # LangChain
        self.llm = None
        self.embeddings = None
        self.current_model = ""
        self.current_embedding = ""
        self.db = None                   # Vector DB
        self.memory = None               # LangChain chat memory
        self.conversation_chain = None   # ConversationChain
        self.query_generator = None      # Query refiner that uses the chat history
        self.qa_chain = None             # load_qa_chain
        self.embedded_urls = []
        self.similarity_search_k = None  # No. of similarity-search documents to retrieve
        self.summarization_mode = None   # stuff / map_reduce / refine

        # Apps
        self.dialogue = []               # recent chat history for display
    # --------------------------------------
    # Empty cache
    # --------------------------------------
    def cache_clear(self):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # clear GPU memory

        gc.collect()  # clear CPU memory

    # --------------------------------------
    # Clear models (llm: LLM model, embd: embeddings, db: vector DB)
    # --------------------------------------
    def clear_memory(self, llm=False, embd=False, db=False):
        # DB
        if db and self.db:
            self.db.delete_collection()
            self.db = None
            self.embedded_urls = []

        # Embeddings model
        if llm or embd:
            self.embeddings = None
            self.current_embedding = ""
            self.qa_chain = None

        # LLM model
        if llm:
            self.llm = None
            self.pipe = None
            self.model = None
            self.current_model = ""
            self.tokenizer = None
            self.memory = None
            self.chat_history = []  # TODO: verify whether this attribute is still needed

        self.cache_clear()
# --------------------------------------
# Custom TextSplitter (splits text so chunks fit within the LLM's token limit)
# (Reference) https://www.sato-susumu.com/entry/2023/04/30/131338
#   → separators such as "!", "?", ")", ".", "!", "?", and "," were added
# --------------------------------------
class JPTextSplitter(RecursiveCharacterTextSplitter):
    def __init__(self, **kwargs: Any):
        separators = ["\n\n", "\n", "。", "!", "?", ")", "、", ".", "!", "?", ",", " ", ""]
        super().__init__(separators=separators, **kwargs)

# Chunk-splitting parameters
chunk_size = 512
chunk_overlap = 35

text_splitter = JPTextSplitter(
    chunk_size=chunk_size,        # max characters per chunk
    chunk_overlap=chunk_overlap,  # max overlapping characters
)
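# Illustrative sketch (not part of the app): the splitter tries Japanese
# sentence boundaries first and only falls back to character-level cuts.
# The sample string below is hypothetical.
#   sample = "一文目です。二文目はもう少し長い文です。" * 50
#   chunks = text_splitter.split_text(sample)   # list[str]
#   assert all(len(c) <= chunk_size for c in chunks)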
# --------------------------------------
# Translate the memory with DeepL to cut token usage (when using OpenAI models)
# --------------------------------------
DEEPL_API_ENDPOINT = "https://api-free.deepl.com/v2/translate"
DEEPL_API_KEY = os.getenv("DEEPL_API_KEY")

def deepl_memory(ss: SessionState) -> (SessionState):
    if ss.current_model == "gpt-3.5-turbo":
        # Fetch the latest exchange from memory
        user_message = ss.memory.chat_memory.messages[-2].content
        ai_message = ss.memory.chat_memory.messages[-1].content
        text = [user_message, ai_message]

        # DeepL settings
        params = {
            "auth_key": DEEPL_API_KEY,
            "text": text,
            "target_lang": "EN",
            "source_lang": "JA",
            "tag_handling": "xml",
            "ignore_tags": "x",
        }
        request = requests.post(DEEPL_API_ENDPOINT, data=params)
        request.raise_for_status()  # raise an exception if the response status code indicates an error
        response = request.json()

        # Extract the translations from the JSON response
        user_message = response["translations"][0]["text"]
        ai_message = response["translations"][1]["text"]

        # Drop the last exchange from memory and add the translated one instead
        ss.memory.chat_memory.messages = ss.memory.chat_memory.messages[:-2]
        ss.memory.chat_memory.add_user_message(user_message)
        ss.memory.chat_memory.add_ai_message(ai_message)

    return ss
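# For reference, the DeepL v2 /translate endpoint returns JSON shaped like
# (values below are illustrative, not real output):
#   {"translations": [
#       {"detected_source_language": "JA", "text": "first translated text"},
#       {"detected_source_language": "JA", "text": "second translated text"}
#   ]}
# which is why the two inputs come back as translations[0] and translations[1].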
# --------------------------------------
# Append DuckDuckGo web-search results to the input prompt
# --------------------------------------
# DEEPL_API_ENDPOINT = "https://api-free.deepl.com/v2/translate"
# DEEPL_API_KEY = os.getenv("DEEPL_API_KEY")

def web_search(query, current_model) -> str:
    search = DuckDuckGoSearchRun()
    web_result = search(query)

    if current_model == "gpt-3.5-turbo":
        text = [query, web_result]
        params = {
            "auth_key": DEEPL_API_KEY,
            "text": text,
            "target_lang": "EN",
            "source_lang": "JA",
            "tag_handling": "xml",
            "ignore_tags": "x",
        }
        request = requests.post(DEEPL_API_ENDPOINT, data=params)
        response = request.json()

        query = response["translations"][0]["text"]
        web_result = response["translations"][1]["text"]

    web_query = query + "\nUse the following information as a reference to answer the question above in Japanese.\n===\nReference: " + web_result + "\n==="

    return web_query
# --------------------------------------
# Assorted LangChain custom prompts
# llama tokenizer
# https://belladoreai.github.io/llama-tokenizer-js/example-demo/build/

# OpenAI tokenizer
# https://platform.openai.com/tokenizer
# --------------------------------------

# --------------------------------------
# Conversation Chain Template
# --------------------------------------

# Tokens: OpenAI 104 / Llama 105  <- in Japanese: OpenAI 191 / Llama 162
sys_chat_message = """
The following is a conversation between an AI concierge and a customer.
The AI understands what the customer wants to know from the conversation history and the latest question,
and gives many specific details in Japanese. If the AI does not know the answer to a question, it does not
make up an answer and says "誠に申し訳ございませんが、その点についてはわかりかねます".
""".replace("\n", "")

chat_common_format = """
===
Question: {query}

Conversation History:
{chat_history}

日本語の回答: """

chat_template_std = f"{sys_chat_message}{chat_common_format}"
chat_template_llama2 = f"<s>[INST] <<SYS>>{sys_chat_message}<</SYS>>{chat_common_format}[/INST]"
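# For orientation, chat_template_llama2 renders to a prompt of roughly this
# shape (whitespace condensed; the question text is a made-up example):
#   <s>[INST] <<SYS>>The following is a conversation between an AI concierge
#   and a customer. ...<</SYS>>
#   ===
#   Question: 営業時間を教えてください
#   Conversation History:
#   ...
#   日本語の回答: [/INST]
# This follows the Llama-2 chat convention of wrapping the system prompt in
# <<SYS>> tags inside an [INST] ... [/INST] block.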
# --------------------------------------
# QA Chain Template (Stuff)
# --------------------------------------
# Tokens: OpenAI 113 / Llama 111  <- in Japanese: OpenAI 256 / Llama 225
sys_qa_message = """
You are an AI concierge who carefully answers questions from customers based on references.
You understand what the customer wants to know from the Conversation History and Question,
and give a specific answer in Japanese using sentences extracted from the following references.
If you do not know the answer, do not make up an answer and reply,
"誠に申し訳ございませんが、その点についてはわかりかねます".
""".replace("\n", "")

qa_common_format = """
===
Question: {query}
References: {context}
Conversation History:
{chat_history}

日本語の回答: """

qa_template_std = f"{sys_qa_message}{qa_common_format}"
qa_template_llama2 = f"<s>[INST] <<SYS>>{sys_qa_message}<</SYS>>{qa_common_format}[/INST]"

# --------------------------------------
# QA Chain Template (Map Reduce)
# --------------------------------------
# 1. Prompt for the chain that builds a standalone question from the chat history and the latest question
query_generator_message = """
Referring to the "Conversation History", reformat the user's "Additional Question"
to a specific question in Japanese by filling in the missing subject, verb, objects,
complements, and other necessary information to get a better search result.
""".replace("\n", "")

query_generator_common_format = """
===
[Conversation History]
{chat_history}

[Additional Question] {query}
明確な質問文: """

query_generator_template_std = f"{query_generator_message}{query_generator_common_format}"
query_generator_template_llama2 = f"<s>[INST] <<SYS>>{query_generator_message}<</SYS>>{query_generator_common_format}[/INST]"


# 2. Prompt for the chain that summarizes each reference using the generated question
question_prompt_message = """
From the following references, extract key information relevant to the question
and summarize it in a natural English sentence with clear subject, verb, object,
and complement.
""".replace("\n", "")

question_prompt_common_format = """
===
[references] {context}
[Question] {query}
[Summary] """

question_prompt_template_std = f"{question_prompt_message}{question_prompt_common_format}"
question_prompt_template_llama2 = f"<s>[INST] <<SYS>>{question_prompt_message}<</SYS>>{question_prompt_common_format}[/INST]"


# 3. Prompt for the chain that answers from the generated question and the vector-database summaries
combine_prompt_message = """
You are an AI concierge who carefully answers questions from customers based on references.
Provide a specific answer in Japanese using sentences extracted from the following references.
If you do not know the answer, do not make up an answer and reply,
"誠に申し訳ございませんが、その点についてはわかりかねます".
""".replace("\n", "")

combine_prompt_common_format = """
===
Question:
{query}
===
Reference: {summaries}
===
日本語の回答: """

combine_prompt_template_std = f"{combine_prompt_message}{combine_prompt_common_format}"
combine_prompt_template_llama2 = f"<s>[INST] <<SYS>>{combine_prompt_message}<</SYS>>{combine_prompt_common_format}[/INST]"


# --------------------------------------
# Summary prompt for ConversationSummaryBufferMemory
# Source → https://github.com/langchain-ai/langchain/blob/894c272a562471aadc1eb48e4a2992923533dea0/langchain/memory/prompt.py#L26-L49
# --------------------------------------
# Tokens: OpenAI 212 / Llama 214  <- in Japanese: OpenAI 397 / Llama 297
conversation_summary_template = """
Using the example as a guide, compose a summary in English that gives an overview of the conversation by summarizing the "current summary" and the "new conversation".
===
Example
[Current Summary] Customer asks AI what it thinks about Artificial Intelligence, AI says Artificial Intelligence is a good tool.

[New Conversation]
Human: なぜ人工知能が良いツールだと思いますか?
AI: 人工知能は「人間の可能性を最大限に引き出すことを助ける」からです。

[New Summary] Customer asks what you think about Artificial Intelligence, and AI responds that it is a good force that helps humans reach their full potential.
===
[Current Summary] {summary}

[New Conversation]
{new_lines}

[New Summary]
""".strip()
# Load models
def load_models(
    ss: SessionState,
    model_id: str,
    embedding_id: str,
    openai_api_key: str,
    load_in_8bit: bool,
    verbose: bool,
    temperature: float,
    similarity_search_k: int,
    summarization_mode: str,
    min_length: int,
    max_new_tokens: int,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    num_return_sequences: int,
) -> (SessionState, str):

    # --------------------------------------
    # Save the settings
    # --------------------------------------
    ss.similarity_search_k = similarity_search_k
    ss.summarization_mode = summarization_mode

    # --------------------------------------
    # Check the OpenAI API key
    # --------------------------------------
    if (model_id == "gpt-3.5-turbo" or embedding_id == "text-embedding-ada-002"):
        # Pre-check
        if not os.environ.get("OPENAI_API_KEY"):
            status_message = "❌ OpenAI API KEY を設定してください"
            return ss, status_message

    # --------------------------------------
    # LLM setup
    # --------------------------------------
    # OpenAI model
    if model_id == "gpt-3.5-turbo":
        ss.clear_memory(llm=True, db=True)
        ss.llm = ChatOpenAI(
            model_name=model_id,
            temperature=temperature,
            verbose=verbose,
            max_tokens=max_new_tokens,
        )

    # Hugging Face GPT model
    else:
        ss.clear_memory(llm=True, db=True)

        if model_id == "rinna/bilingual-gpt-neox-4b-instruction-sft":
            ss.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
        else:
            ss.tokenizer = AutoTokenizer.from_pretrained(model_id)

        ss.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            load_in_8bit=load_in_8bit,
            torch_dtype=torch.float16,
            device_map="auto",
        )

        ss.pipe = pipeline(
            "text-generation",
            model=ss.model,
            tokenizer=ss.tokenizer,
            min_length=min_length,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            num_return_sequences=num_return_sequences,
            temperature=temperature,
        )
        ss.llm = HuggingFacePipeline(pipeline=ss.pipe)

    # --------------------------------------
    # Embedding-model setup
    # --------------------------------------
    if ss.current_embedding == embedding_id:
        pass

    else:
        # Reset embeddings and vectordb
        ss.clear_memory(embd=True, db=True)

        if embedding_id == "None":
            pass

        # OpenAI
        elif embedding_id == "text-embedding-ada-002":
            ss.embeddings = OpenAIEmbeddings()

        # Hugging Face
        else:
            ss.embeddings = HuggingFaceEmbeddings(model_name=embedding_id)

    # --------------------------------------
    # Record the current model names on the SessionState object
    # (done before building the chains so that set_chains can pick
    #  the model-specific prompt templates)
    # --------------------------------------
    ss.current_model = model_id
    ss.current_embedding = embedding_id

    # --------------------------------------
    # Chain setup
    # --------------------------------------
    ss = set_chains(ss, summarization_mode)

    # Status message
    status_message = "✅ LLM: " + ss.current_model + ", embeddings: " + ss.current_embedding

    return ss, status_message
# --------------------------------------
# Unified setup for the Conversation/QA chains
# --------------------------------------
def set_chains(ss: SessionState, summarization_mode) -> SessionState:

    # Pick the chat templates that match the model
    human_prefix = "Human: "
    ai_prefix = "AI: "
    chat_template = chat_template_std
    qa_template = qa_template_std
    query_generator_template = query_generator_template_std
    question_template = question_prompt_template_std
    combine_template = combine_prompt_template_std

    if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft":
        # Rinna-specific settings: newline-code fix and memory prefixes (see the official model page)
        chat_template = chat_template.replace("\n", "<NL>")
        qa_template = qa_template.replace("\n", "<NL>")
        query_generator_template = query_generator_template_std.replace("\n", "<NL>")
        question_template = question_prompt_template_std.replace("\n", "<NL>")
        combine_template = combine_prompt_template_std.replace("\n", "<NL>")
        human_prefix = "ユーザー: "
        ai_prefix = "システム: "

    elif ss.current_model.startswith("elyza/ELYZA-japanese-Llama-2-7b"):
        # ELYZA-specific templates (Llama-2 prompt format)
        chat_template = chat_template_llama2
        qa_template = qa_template_llama2
        query_generator_template = query_generator_template_llama2
        question_template = question_prompt_template_llama2
        combine_template = combine_prompt_template_llama2

    # --------------------------------------
    # Memory setup
    # --------------------------------------
    if ss.memory is None:
        conversation_summary_prompt = PromptTemplate(input_variables=['summary', 'new_lines'], template=conversation_summary_template)
        ss.memory = ConversationSummaryBufferMemory(
            llm=ss.llm,
            memory_key="chat_history",
            input_key="query",
            output_key="output_text",
            return_messages=False,
            human_prefix=human_prefix,
            ai_prefix=ai_prefix,
            max_token_limit=1024,
            prompt=conversation_summary_prompt,
        )

    # --------------------------------------
    # Conversation/QA chain setup
    # --------------------------------------
    if ss.conversation_chain is None:
        chat_prompt = PromptTemplate(input_variables=['query', 'chat_history'], template=chat_template)
        ss.conversation_chain = ConversationChain(
            llm=ss.llm,
            prompt=chat_prompt,
            memory=ss.memory,
            input_key="query",
            output_key="output_text",
            verbose=True,
        )

    if ss.qa_chain is None:
        if summarization_mode == "stuff":
            qa_prompt = PromptTemplate(input_variables=['context', 'query', 'chat_history'], template=qa_template)
            ss.qa_chain = load_qa_chain(ss.llm, chain_type="stuff", memory=ss.memory, prompt=qa_prompt)

        elif summarization_mode == "map_reduce":
            query_generator_prompt = PromptTemplate(template=query_generator_template, input_variables=["chat_history", "query"])
            ss.query_generator = LLMChain(llm=ss.llm, prompt=query_generator_prompt)

            question_prompt = PromptTemplate(template=question_template, input_variables=["context", "query"])
            combine_prompt = PromptTemplate(template=combine_template, input_variables=["summaries", "query"])
            ss.qa_chain = load_qa_chain(ss.llm, chain_type="map_reduce", return_map_steps=True, memory=ss.memory, question_prompt=question_prompt, combine_prompt=combine_prompt)

    return ss
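# Sketch of how the pieces built above are driven later in this file
# (mirrors bot / chat_predict / qa_predict below; shown here for orientation):
#   answer  = ss.conversation_chain.predict(query="...")              # plain chat
#   result  = ss.qa_chain({"input_documents": docs, "query": "..."})  # RAG answer
#   refined = ss.query_generator({"query": "...", "chat_history": h})["text"]
# (query_generator is only built in map_reduce mode)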
def initialize_db(ss: SessionState) -> SessionState:

    # client = chromadb.PersistentClient(path="./db")
    ss.db = Chroma(
        collection_name="user_reference",
        embedding_function=ss.embeddings,
        # client = client
    )

    return ss
def embedding_process(ss: SessionState, ref_documents: Document) -> SessionState:

    # --------------------------------------
    # Restructure the text and strip unwanted strings
    # --------------------------------------
    for i in range(len(ref_documents)):
        content = ref_documents[i].page_content.strip()

        # --------------------------------------
        # For PDFs, apply stronger fixes to work around text-extraction errors
        # --------------------------------------
        if ".pdf" in ref_documents[i].metadata['source']:
            pdf_replacement_sets = [
                ('\n ', '**PLACEHOLDER+SPACE**'),
                ('\n\u3000', '**PLACEHOLDER+SPACE**'),
                ('.\n', '。**PLACEHOLDER**'),
                (',\n', '。**PLACEHOLDER**'),
                ('?\n', '。**PLACEHOLDER**'),
                ('!\n', '。**PLACEHOLDER**'),
                ('!\n', '。**PLACEHOLDER**'),
                ('。\n', '。**PLACEHOLDER**'),
                ('!\n', '!**PLACEHOLDER**'),
                (')\n', '!**PLACEHOLDER**'),
                (']\n', '!**PLACEHOLDER**'),
                ('?\n', '?**PLACEHOLDER**'),
                (')\n', '?**PLACEHOLDER**'),
                ('】\n', '?**PLACEHOLDER**'),
            ]
            for original, replacement in pdf_replacement_sets:
                content = content.replace(original, replacement)
            content = content.replace(" ", "")
        # --------------------------------------

        # Remove unwanted strings and whitespace
        remove_texts = ["\n", "\r", " "]
        for remove_text in remove_texts:
            content = content.replace(remove_text, "")

        # Convert tabs and ideographic spaces to single spaces
        replace_texts = ["\t", "\u3000"]
        for replace_text in replace_texts:
            content = content.replace(replace_text, " ")

        # Restore the legitimate PDF line breaks
        if ".pdf" in ref_documents[i].metadata['source']:
            content = content.replace('**PLACEHOLDER**', '\n').replace('**PLACEHOLDER+SPACE**', '\n ')

        ref_documents[i].page_content = content

    # --------------------------------------
    # Split into chunks
    texts = text_splitter.split_documents(ref_documents)

    # --------------------------------------
    # Add the prefix expected by the multilingual-e5 training setup
    # https://hironsan.hatenablog.com/entry/2023/07/05/073150
    # --------------------------------------
    if ss.current_embedding == "intfloat/multilingual-e5-large":
        for i in range(len(texts)):
            texts[i].page_content = "passage:" + texts[i].page_content

    # Initialize the vectordb
    if ss.db is None:
        ss = initialize_db(ss)

    # Embed into the db
    # ss.db = Chroma.from_documents(texts, ss.embeddings)
    ss.db.add_documents(documents=texts, embedding=ss.embeddings)

    return ss
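# Note on the e5 prefixes: multilingual-e5 was trained with "passage: " /
# "query: " markers, so documents are stored with a "passage:" prefix here and
# queries get a matching "query: " prefix in qa_predict below. A hypothetical
# check of the stored form:
#   docs = ss.db.similarity_search("query: 営業時間は?", k=3)
#   docs[0].page_content   # -> starts with "passage:" for e5 embeddings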
def embed_ref(ss: SessionState, urls: str, fileobj: list, header_lim: int, footer_lim: int) -> (SessionState, str):

    # --------------------------------------
    # Check that the models are loaded
    # --------------------------------------
    if ss.llm is None or ss.embeddings is None:
        status_message = "❌ LLM/Embeddingモデルが登録されていません。"
        return ss, status_message

    url_flag = "-"
    pdf_flag = "-"

    # --------------------------------------
    # Load URLs and register them in the vectordb
    # --------------------------------------

    # Preprocess the URL list (split into a list, deduplicate, drop non-URLs)
    urls = list({url for url in urls.split("\n") if url and "://" in url})

    if urls:
        # Drop URLs that are already registered (ss.embedded_urls), then record the new ones
        urls = [url for url in urls if url not in ss.embedded_urls]
        ss.embedded_urls.extend(urls)

        # Load the web pages
        loader = SeleniumURLLoader(urls=urls)
        ref_documents = loader.load()

        # Run the embedding process
        ss = embedding_process(ss, ref_documents)

        url_flag = "✅ 登録済"

    # --------------------------------------
    # Strip PDF headers and footers, then register the body in the vectordb
    #   https://pypdf.readthedocs.io/en/stable/user/extract-text.html
    # --------------------------------------

    if fileobj is None:
        pass

    else:
        # Collect the file paths
        pdf_paths = []
        for path in fileobj:
            pdf_paths.append(path.name)

        # Initialize the document list
        ref_documents = []

        # Read each PDF file
        for pdf_path in pdf_paths:
            pdf = PdfReader(pdf_path)
            body = []

            # `parts` is rebound for each page in the loop below
            def visitor_body(text, cm, tm, font_dict, font_size):
                y = tm[5]
                if y > footer_lim and y < header_lim:  # keep only text whose y coordinate lies between footer and header
                    parts.append(text)

            for page in pdf.pages:
                parts = []
                page.extract_text(visitor_text=visitor_body)
                body.append("".join(parts))

            body = "\n".join(body)

            # Keep only the file name from the path
            filename = os.path.basename(pdf_path)
            # Convert the extracted text into a LangChain Document
            ref_documents.append(Document(page_content=body, metadata={"source": filename}))

        # Run the embedding process
        ss = embedding_process(ss, ref_documents)

        pdf_flag = "✅ 登録済"


    langchain.debug = True

    status_message = "URL: " + url_flag + " / PDF: " + pdf_flag
    return ss, status_message
def clear_db(ss: SessionState) -> (SessionState, str):
    if ss.db is None:
        status_message = "❌ 参照データが登録されていません。"
        return ss, status_message

    try:
        ss.db.delete_collection()
        status_message = "✅ 参照データを削除しました。"

    except NameError:
        status_message = "❌ 参照データが登録されていません。"

    return ss, status_message
# ----------------------------------------------------------------------------
# query input ▶ [def user] ▶ [ def bot ] ▶ [def show_response] ▶ chatbot screen
#                  ⬇             ⬇                ⬆
#           chatbot screen      [qa_predict / conversation_predict]
# ----------------------------------------------------------------------------

def user(ss: SessionState, query) -> (SessionState, list):
    # Once the conversation history exceeds a fixed length, drop the oldest entry
    if len(ss.dialogue) > 20:
        ss.dialogue.pop(0)

    ss.dialogue = ss.dialogue + [(query, None)]  # conversation history (None is the empty slot for the bot's reply)
    chat_history = ss.dialogue

    # The chat screen displays chat_history
    return ss, chat_history
def bot(ss: SessionState, query, qa_flag, web_flag, summarization_mode) -> (SessionState, str):

    original_query = query

    if ss.llm is None:
        response = "LLMが設定されていません。設定画面で任意のモデルを選択してください。"
        ss.dialogue[-1] = (ss.dialogue[-1][0], response)
        return ss, ""

    elif qa_flag is True and ss.embeddings is None:
        response = "Embeddingモデルが設定されていません。設定画面で任意のモデルを選択してください。"
        ss.dialogue[-1] = (ss.dialogue[-1][0], response)
        return ss, ""

    elif qa_flag is True and ss.db is None:
        response = "参照データが登録されていません。"
        ss.dialogue[-1] = (ss.dialogue[-1][0], response)
        return ss, ""

    # Refine query
    history = ss.memory.load_memory_variables({})
    if ss.query_generator is not None and history['chat_history'] != "":
        # Refine the query from the chat history (query_generator exists only in map_reduce mode)
        query = ss.query_generator({"query": query, "chat_history": history})['text']

    # QA model
    if qa_flag is True and ss.embeddings is not None and ss.db is not None:
        if web_flag:
            web_query = web_search(query, ss.current_model)
            ss = qa_predict(ss, web_query)
            ss.memory.chat_memory.messages[-2].content = query
        else:
            ss = qa_predict(ss, query)  # generate the answer with the LLM

    # Chat model
    else:
        if web_flag:
            web_query = web_search(query, ss.current_model)
            ss = chat_predict(ss, web_query)
            ss.memory.chat_memory.messages[-2].content = query
        else:
            ss = chat_predict(ss, query)

    # When a GPT model is in use, translate the memory into English with DeepL
    ss = deepl_memory(ss)

    return ss, ""  # ss and the (now emptied) query box
def chat_predict(ss: SessionState, query) -> SessionState:
    response = ss.conversation_chain.predict(query=query)
    ss.dialogue[-1] = (ss.dialogue[-1][0], response)
    return ss
def qa_predict(ss: SessionState, query) -> SessionState:

    # Rinna-specific setting (fix the newline code in the query)
    if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft":
        query = query.strip().replace("\n", "<NL>")
    else:
        query = query.strip()

    # Query prefix for multilingual-e5
    if ss.current_embedding == "intfloat/multilingual-e5-large":
        db_query_str = "query: " + query
    else:
        db_query_str = query

    # Retrieve related documents and their sources from the DB
    docs = ss.db.similarity_search(db_query_str, k=ss.similarity_search_k)
    sources = "\n\n[Sources]\n" + '\n - '.join(list(set(doc.metadata['source'] for doc in docs if 'source' in doc.metadata)))

    # Rinna-specific setting (fix the newline code in the retrieved documents)
    if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft":
        for i in range(len(docs)):
            docs[i].page_content = docs[i].page_content.strip().replace("\n", "<NL>")

    # Generate the answer (up to 3 attempts)
    for _ in range(3):
        result = ss.qa_chain({"input_documents": docs, "query": query})
        result["output_text"] = result["output_text"].replace("<NL>", "\n").strip("...").strip("回答:").strip()

        # If result["output_text"] is not empty, update the memory and return
        if result["output_text"] != "":
            response = result["output_text"] + sources
            ss.dialogue[-1] = (ss.dialogue[-1][0], response)
            return ss
        else:
            # If it is empty, drop the latest exchange from memory and retry
            ss.memory.chat_memory.messages = ss.memory.chat_memory.messages[:-2]

    # Still empty after 3 attempts
    response = "3回試行しましたが、情報を生成できませんでした。"
    if sources != "":
        response += "参考文献の抽出には成功していますので、言語モデルを変えてお試しください。"

    # Append the user message and the AI message
    ss.memory.chat_memory.add_user_message(query.replace("<NL>", "\n"))
    ss.memory.chat_memory.add_ai_message(response)
    ss.dialogue[-1] = (ss.dialogue[-1][0], response)  # conversation history
    return ss
# Stream the answer to the chat screen one character at a time
def show_response(ss: SessionState) -> str:
    chat_history = [list(item) for item in ss.dialogue]  # convert tuples to lists to get the conversation history from memory
    response = chat_history[-1][1]  # take the latest exchange [-1] and save the chatbot's answer [1]
    chat_history[-1][1] = ""  # empty the chatbot's answer [1] so it can be re-displayed incrementally

    if response is None:
        response = "回答を生成できませんでした。"

    for character in response:
        chat_history[-1][1] += character
        time.sleep(0.05)
        yield chat_history
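# Gradio treats a generator event handler as a streaming update: each yielded
# chat_history repaints the Chatbot component, so the reply appears to type
# itself out at roughly 20 characters per second (0.05 s sleep per character).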
with gr.Blocks() as demo:

    # Instantiate the per-user session memory (reset on page reload)
    ss = gr.State(SessionState())

    # --------------------------------------
    # Functions to set / clear the API key
    # --------------------------------------
    def openai_api_setfn(openai_api_key) -> str:
        if openai_api_key == "kikagaku":
            os.environ["OPENAI_API_KEY"] = os.getenv("kikagaku_demo", "")  # fall back to "" if the demo key is missing
            status_message = "✅ キカガク専用DEMOへようこそ!APIキーを設定しました"
            return status_message
        elif not openai_api_key or not openai_api_key.startswith("sk-") or len(openai_api_key) < 50:
            os.environ["OPENAI_API_KEY"] = ""
            status_message = "❌ 有効なAPIキーを入力してください"
            return status_message
        else:
            os.environ["OPENAI_API_KEY"] = openai_api_key
            status_message = "✅ APIキーを設定しました"
            return status_message

    def openai_api_clsfn(ss) -> (str, str):
        openai_api_key = ""
        os.environ["OPENAI_API_KEY"] = ""
        status_message = "✅ APIキーの削除が完了しました"
        return status_message, ""
    with gr.Tabs():
        # --------------------------------------
        # Setting Tab
        # --------------------------------------
        with gr.TabItem("1. LLM設定"):
            with gr.Row():
                model_id = gr.Dropdown(
                    choices=[
                        'elyza/ELYZA-japanese-Llama-2-7b-fast-instruct',
                        'rinna/bilingual-gpt-neox-4b-instruction-sft',
                        'gpt-3.5-turbo',
                    ],
                    value="gpt-3.5-turbo",
                    label='LLM model',
                    interactive=True,
                )
            with gr.Row():
                embedding_id = gr.Dropdown(
                    choices=[
                        'intfloat/multilingual-e5-large',
                        'sonoisa/sentence-bert-base-ja-mean-tokens-v2',
                        'oshizo/sbert-jsnli-luke-japanese-base-lite',
                        'text-embedding-ada-002',
                        "None"
                    ],
                    value="text-embedding-ada-002",
                    label='Embedding model',
                    interactive=True,
                )
            with gr.Row():
                with gr.Column(scale=19):
                    openai_api_key = gr.Textbox(label="OpenAI API Key (Optional)", interactive=True, type="password", value="", placeholder="Your OpenAI API Key for OpenAI models.", max_lines=1)
                with gr.Column(scale=1):
                    openai_api_set = gr.Button(value="Set API KEY", size="sm")
                    openai_api_cls = gr.Button(value="Delete API KEY", size="sm")

            # with gr.Row():
            #     reference_libs = gr.CheckboxGroup(choices=['LangChain', 'Gradio'], label="Reference Libraries", interactive=False)

            # Advanced settings (collapsed by default)
            with gr.Accordion(label="Advanced Setting", open=False):
                with gr.Row():
                    with gr.Column():
                        load_in_8bit = gr.Checkbox(label="8bit Quantize (HF)", value=True, interactive=True)
                        verbose = gr.Checkbox(label="Verbose (OpenAI, HF)", value=True, interactive=True)
                    with gr.Column():
                        temperature = gr.Slider(label='Temperature (OpenAI, HF)', minimum=0.0, maximum=1.0, step=0.1, value=0.2, interactive=True)
                    with gr.Column():
                        similarity_search_k = gr.Slider(label="similarity_search_k (OpenAI, HF)", minimum=1, maximum=10, step=1, value=3, interactive=True)
                    with gr.Column():
                        summarization_mode = gr.Radio(choices=['stuff', 'map_reduce'], label="Summarization mode", value='stuff', interactive=True)
                    with gr.Column():
                        min_length = gr.Slider(label="min_length (HF)", minimum=1, maximum=100, step=1, value=10, interactive=True)
                    with gr.Column():
                        max_new_tokens = gr.Slider(label="max_tokens(OpenAI), max_new_tokens(HF)", minimum=1, maximum=1024, step=1, value=256, interactive=True)
                    with gr.Column():
                        top_k = gr.Slider(label='top_k (HF)', minimum=1, maximum=100, step=1, value=40, interactive=True)
                    with gr.Column():
                        top_p = gr.Slider(label='top_p (HF)', minimum=0.01, maximum=0.99, step=0.01, value=0.92, interactive=True)
                    with gr.Column():
                        repetition_penalty = gr.Slider(label='repetition_penalty (HF)', minimum=0.5, maximum=2, step=0.1, value=1.2, interactive=True)
                    with gr.Column():
                        num_return_sequences = gr.Slider(label='num_return_sequences (HF)', minimum=1, maximum=20, step=1, value=3, interactive=True)

            with gr.Row():
                with gr.Column(scale=2):
                    config_btn = gr.Button(value="Configure")
                with gr.Column(scale=13):
                    status_cfg = gr.Textbox(show_label=False, interactive=False, value="モデルを設定してください", container=False, max_lines=1)

            # Wire up the buttons and other actions
            openai_api_set.click(openai_api_setfn, inputs=[openai_api_key], outputs=[status_cfg], show_progress="full")
            openai_api_cls.click(openai_api_clsfn, inputs=[openai_api_key], outputs=[status_cfg, openai_api_key], show_progress="full")
            openai_api_key.submit(openai_api_setfn, inputs=[openai_api_key], outputs=[status_cfg], show_progress="full")
            config_btn.click(
                fn=load_models,
                inputs=[ss, model_id, embedding_id, openai_api_key, load_in_8bit, verbose, temperature,
                        similarity_search_k, summarization_mode,
                        min_length, max_new_tokens, top_k, top_p, repetition_penalty, num_return_sequences],
                outputs=[ss, status_cfg],
                queue=True,
                show_progress="full"
            )
        # --------------------------------------
        # Reference Tab
        # --------------------------------------
        with gr.TabItem("2. References"):
            urls = gr.TextArea(
                max_lines=60,
                show_label=False,
                info="List any reference URLs for Q&A retrieval.",
                placeholder="https://blog.kikagaku.co.jp/deep-learning-transformer\nhttps://note.com/elyza/n/na405acaca130",
                interactive=True,
            )

            with gr.Row():
                pdf_paths = gr.File(label="PDFs", height=150, min_width=60, scale=7, file_types=[".pdf"], file_count="multiple", interactive=True)
                # NOTE: 792 pt is the US-Letter page height; A4 pages are about 842 pt tall.
                header_lim = gr.Number(label="Header (pt)", step=1, value=792, precision=0, min_width=70, scale=1, interactive=True)
                footer_lim = gr.Number(label="Footer (pt)", step=1, value=0, precision=0, min_width=70, scale=1, interactive=True)
                pdf_ref = gr.Textbox(show_label=False, value="A4 Size:\n(下)0-792pt(上)\n *28.35pt/cm", container=False, scale=1, interactive=False)

            with gr.Row():
                ref_set_btn = gr.Button(value="コンテンツ登録", scale=1)
                ref_clear_btn = gr.Button(value="登録データ削除", scale=1)
                status_ref = gr.Textbox(show_label=False, interactive=False, value="参照データ未登録", container=False, max_lines=1, scale=18)

            ref_set_btn.click(fn=embed_ref, inputs=[ss, urls, pdf_paths, header_lim, footer_lim], outputs=[ss, status_ref], queue=True, show_progress="full")
            ref_clear_btn.click(fn=clear_db, inputs=[ss], outputs=[ss, status_ref], show_progress="full")
        # --------------------------------------
        # Chatbot Tab
        # --------------------------------------
        with gr.TabItem("3. Q&A Chat"):
            chat_history = gr.Chatbot([], elem_id="chatbot", avatar_images=["bear.png", "penguin.png"])
            with gr.Row():
                with gr.Column(scale=95):
                    query = gr.Textbox(
                        show_label=False,
                        placeholder="Send a message with [Shift]+[Enter] key.",
                        lines=4,
                        container=False,
                        autofocus=True,
                        interactive=True,
                    )
                with gr.Column(scale=5):
                    with gr.Row():
                        qa_flag = gr.Checkbox(label="QA mode", value=True, min_width=60, interactive=True)
                        web_flag = gr.Checkbox(label="Web Search", value=False, min_width=60, interactive=True)
                    with gr.Row():
                        query_send_btn = gr.Button(value="▶")

            # gr.Examples(["機械学習について説明してください"], inputs=[query])
            query.submit(
                user, [ss, query], [ss, chat_history]
            ).then(
                bot, [ss, query, qa_flag, web_flag, summarization_mode], [ss, query]
            ).then(
                show_response, [ss], [chat_history]
            )

            query_send_btn.click(
                user, [ss, query], [ss, chat_history]
            ).then(
                bot, [ss, query, qa_flag, web_flag, summarization_mode], [ss, query]
            ).then(
                show_response, [ss], [chat_history]
            )

if __name__ == "__main__":
    demo.queue(concurrency_count=5)
    demo.launch(debug=True)
bear.png ADDED
packages.txt ADDED
@@ -0,0 +1 @@
chromium-driver
penguin.png ADDED
requirements.txt ADDED
@@ -0,0 +1,23 @@
accelerate==0.22.0
beautifulsoup4==4.11.2
bitsandbytes==0.41.1
transformers==4.30.0
sentence-transformers==2.2.2
sentencepiece==0.1.99
langchain==0.0.281
xformers==0.0.21
chromadb==0.4.8
gradio==3.42.0
gradio_client==0.5.0
openai==0.28.0
tiktoken==0.4.0
fugashi==1.3.0
ipadic==1.0.0
unstructured==0.10.12
selenium==4.12.0
pypdf==3.15.5
Cython==0.29.36
numpy==1.23.5
pandas==1.5.3
chromedriver-autoinstaller
chromedriver-binary