ver 0.7 for test
- app.py +869 -0
- requirements.txt +18 -0
app.py
ADDED
@@ -0,0 +1,869 @@
# --------------------------------------
# Libraries
# --------------------------------------
import os
import time
import gc        # memory cleanup
import re        # regex for text cleanup
import requests  # DeepL API calls (required by deepl_memory below)

# HuggingFace
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# OpenAI
import openai
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI

# LangChain
from langchain.llms import HuggingFacePipeline
from transformers import pipeline

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import VectorDBQA
from langchain.vectorstores import Chroma

from langchain import PromptTemplate, ConversationChain
from langchain.chains.question_answering import load_qa_chain  # QA chat
from langchain.document_loaders import SeleniumURLLoader       # URL fetching
from langchain.docstore.document import Document               # wrap text as Documents
# from langchain.memory import ConversationBufferWindowMemory  # chat history
from langchain.memory import ConversationSummaryBufferMemory   # chat history

from typing import Any
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Gradio
import gradio as gr

# PyPdf
from pypdf import PdfReader

# test
import langchain  # (to enable langchain.debug = True)

# --------------------------------------
# Class holding per-user session variables
# (Reference) https://blog.shikoan.com/gradio-state/
# --------------------------------------
class SessionState:
    def __init__(self):
        # Hugging Face
        self.tokenizer = None
        self.pipe = None
        self.model = None

        # LangChain
        self.llm = None
        self.embeddings = None
        self.current_model = ""
        self.current_embedding = ""
        self.db = None                  # vector DB
        self.memory = None              # LangChain chat memory
        self.qa_chain = None            # load_qa_chain
        self.conversation_chain = None  # ConversationChain
        self.embedded_urls = []

        # Apps
        self.dialogue = []  # recent chat history for display

    # --------------------------------------
    # Empty cache
    # --------------------------------------
    def cache_clear(self):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # GPU memory clear

        gc.collect()  # CPU memory clear

    # --------------------------------------
    # Clear models (llm: LLM model, embd: embeddings, db: vector DB)
    # --------------------------------------
    def clear_memory(self, llm=False, embd=False, db=False):
        # DB
        if db and self.db:
            self.db.delete_collection()
            self.db = None
            self.embedded_urls = []

        # Embeddings model
        if llm or embd:
            self.embeddings = None
            self.current_embedding = ""
            self.qa_chain = None

        # LLM model
        if llm:
            self.llm = None
            self.pipe = None
            self.model = None
            self.current_model = ""
            self.tokenizer = None
            self.memory = None
            self.chat_history = []  # ← necessity still to be verified

        self.cache_clear()

    # --------------------------------------
    # Load chat history as a list
    # --------------------------------------
    def load_chat_history(self) -> list:
        chat_history = []
        try:
            chat_memory = self.memory.load_memory_variables({})['chat_history']
        except (KeyError, AttributeError):  # AttributeError covers self.memory still being None
            return chat_history

        # Read the chat history in user/AI pairs
        for i in range(0, len(chat_memory), 2):
            user_message = chat_memory[i].content
            ai_message = ""
            if i + 1 < len(chat_memory):
                ai_message = chat_memory[i + 1].content
            chat_history.append([user_message, ai_message])
        return chat_history

# --------------------------------------
# Custom TextSplitter (splits text to fit within the LLM's token limit)
# (Reference) https://www.sato-susumu.com/entry/2023/04/30/131338
# → Added separators such as "!", "?", ")", ".", "!", "?", ","
# --------------------------------------
class JPTextSplitter(RecursiveCharacterTextSplitter):
    def __init__(self, **kwargs: Any):
        separators = ["\n\n", "\n", "。", "!", "?", ")", "、", ".", "!", "?", ",", " ", ""]
        super().__init__(separators=separators, **kwargs)

# Chunking parameters
chunk_size = 512
chunk_overlap = 35

text_splitter = JPTextSplitter(
    chunk_size = chunk_size,        # max characters per chunk
    chunk_overlap = chunk_overlap,  # max overlapping characters
)

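# Usage sketch (illustration only, not called by the app): the splitter tries
# Japanese sentence boundaries first and only falls back to single characters,
# so chunks tend to end on "。" rather than mid-sentence.
#   sample = Document(page_content="短い文です。" * 300, metadata={"source": "sample"})
#   chunks = text_splitter.split_documents([sample])
#   assert all(len(c.page_content) <= chunk_size for c in chunks)
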
# --------------------------------------
# Translate memory with DeepL to cut token usage (when using OpenAI models)
# --------------------------------------
DEEPL_API_ENDPOINT = "https://api-free.deepl.com/v2/translate"
DEEPL_API_KEY = "YOUR_DEEPL_API_KEY"

def deepl_memory(ss: SessionState) -> (SessionState):
    if ss.current_model == "gpt-3.5-turbo":
        # Fetch the latest user/AI pair from memory
        # (chat_memory.messages is a flat list, so the pair is the last two entries)
        user_message = ss.memory.chat_memory.messages[-2].content
        ai_message = ss.memory.chat_memory.messages[-1].content
        text = [user_message, ai_message]

        # DeepL settings
        params = {
            "auth_key": DEEPL_API_KEY,
            "text": text,
            "target_lang": "EN",
            "source_lang": "JA"
        }
        request = requests.post(DEEPL_API_ENDPOINT, data=params)
        request.raise_for_status()  # raise an exception on an error status code
        response = request.json()

        # Extract the translations from the JSON response
        user_message = response["translations"][0]["text"]
        ai_message = response["translations"][1]["text"]

        # Drop the last user/AI pair from memory and re-add the translated pair
        ss.memory.chat_memory.messages = ss.memory.chat_memory.messages[:-2]
        ss.memory.chat_memory.add_user_message(user_message)
        ss.memory.chat_memory.add_ai_message(ai_message)

    return ss

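# For reference (hedged reminder, per the DeepL v2 API docs): /translate returns
# JSON of the form
#   {"translations": [{"detected_source_language": "JA", "text": "..."},
#                     {"detected_source_language": "JA", "text": "..."}]}
# with one entry per input string, in order, which is why translations[0] and
# translations[1] above map back to the user and AI messages respectively.
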
# --------------------------------------
# LangChain custom prompts
# llama tokenizer
# https://belladoreai.github.io/llama-tokenizer-js/example-demo/build/

# OpenAI tokenizer
# https://platform.openai.com/tokenizer
# --------------------------------------

# --------------------------------------
# Conversation Chain Template
# --------------------------------------

# Tokens: OpenAI 104 / Llama 105 <- in Japanese: OpenAI 191 / Llama 162
sys_chat_message = """
The following is a conversation between an AI concierge and a customer.
The AI understands what the customer wants to know from the conversation history and the latest question,
and gives many specific details in Japanese. If the AI does not know the answer to a question, it does not
make up an answer and says "誠に申し訳ございませんが、その点についてはわかりかねます".
""".replace("\n", "")

chat_common_format = """
===
Question: {query}
===
Conversation History:
{chat_history}
===
日本語の回答:"""

chat_template_std = f"{sys_chat_message}{chat_common_format}"
chat_template_llama2 = f"<s>[INST] <<SYS>>{sys_chat_message}<</SYS>>{chat_common_format}[/INST]"

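# Rendered example (illustrative): with query="営業時間は?" and an empty
# history, chat_template_llama2 expands to roughly
#   <s>[INST] <<SYS>>The following is a conversation between an AI concierge ...<</SYS>>
#   ===
#   Question: 営業時間は?
#   ===
#   Conversation History:
#   ===
#   日本語の回答:[/INST]
# following the Llama-2 chat convention of wrapping the system prompt in <<SYS>> tags.
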
# --------------------------------------
# QA Chain Template
# --------------------------------------
# Tokens: OpenAI 113 / Llama 111 <- in Japanese: OpenAI 256 / Llama 225
sys_qa_message = """
You are an AI concierge who carefully answers questions from customers based on references.
You understand what the customer wants to know from the "Conversation History" and "Question",
and give a specific answer in Japanese using sentences extracted from the following references.
If you do not know the answer, do not make up an answer and reply,
"誠に申し訳ございませんが、その点についてはわかりかねます".
""".replace("\n", "")

qa_common_format = """
===
Question:
{query}
===
References:
{context}
===
Conversation History:
{chat_history}
===
日本語の回答:"""

qa_template_std = f"{sys_qa_message}{qa_common_format}"
qa_template_llama2 = f"<s>[INST] <<SYS>>{sys_qa_message}<</SYS>>{qa_common_format}[/INST]"

# --------------------------------------
# Summarization prompt for ConversationSummaryBufferMemory
# Source → https://github.com/langchain-ai/langchain/blob/894c272a562471aadc1eb48e4a2992923533dea0/langchain/memory/prompt.py#L26-L49
# --------------------------------------
# Tokens: OpenAI 212 / Llama 214 <- in Japanese: OpenAI 397 / Llama 297
conversation_summary_template = """
Using the example as a guide, compose a summary in English that gives an overview of the conversation by summarizing the "current summary" and the "new conversation".
===
Example
[Current Summary] Customer asks AI what it thinks about Artificial Intelligence, AI says Artificial Intelligence is a good tool.

[New Conversation]
Human: なぜ人工知能が良いツールだと思いますか?
AI: 人工知能は「人間の可能性を最大限に引き出すことを助ける」からです。

[New Summary] Customer asks what you think about Artificial Intelligence, and AI responds that it is a good force that helps humans reach their full potential.
===
[Current Summary] {summary}

[New Conversation]
{new_lines}

[New Summary]
""".strip()

# Model loading
def load_models(
    ss: SessionState,
    model_id: str,
    embedding_id: str,
    openai_api_key: str,
    load_in_8bit: bool,
    verbose: bool,
    temperature: float,
    min_length: int,
    max_new_tokens: int,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    num_return_sequences: int,
) -> (SessionState, str):

    # --------------------------------------
    # Check the OpenAI API key
    # --------------------------------------
    if (model_id == "gpt-3.5-turbo" or embedding_id == "text-embedding-ada-002"):
        # Preliminary check (.get avoids a KeyError when the key was never set)
        if not os.environ.get("OPENAI_API_KEY"):
            status_message = "❌ OpenAI API KEY を設定してください"
            return ss, status_message

    # --------------------------------------
    # LLM setup
    # --------------------------------------
    # OpenAI model
    if model_id == "gpt-3.5-turbo":
        ss.clear_memory(llm=True, db=True)
        ss.llm = ChatOpenAI(
            model_name = model_id,
            temperature = temperature,
            verbose = verbose,
            max_tokens = max_new_tokens,
        )

    # Hugging Face model
    else:
        ss.clear_memory(llm=True, db=True)

        if model_id == "rinna/bilingual-gpt-neox-4b-instruction-sft":
            ss.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
        else:
            ss.tokenizer = AutoTokenizer.from_pretrained(model_id)

        ss.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            load_in_8bit = load_in_8bit,
            torch_dtype = torch.float16,
            device_map = "auto",
        )

        ss.pipe = pipeline(
            "text-generation",
            model = ss.model,
            tokenizer = ss.tokenizer,
            min_length = min_length,
            max_new_tokens = max_new_tokens,
            do_sample = True,
            top_k = top_k,
            top_p = top_p,
            repetition_penalty = repetition_penalty,
            num_return_sequences = num_return_sequences,
            temperature = temperature,
        )
        ss.llm = HuggingFacePipeline(pipeline=ss.pipe)

    # --------------------------------------
    # Embedding model setup
    # --------------------------------------
    if ss.current_embedding == embedding_id:
        # Embeddings unchanged; still return the (ss, status) pair the caller expects
        ss.current_model = model_id
        status_message = "✅ LLM: " + ss.current_model + ", embeddings: " + ss.current_embedding
        return ss, status_message

    # Reset embeddings and vectordb
    ss.clear_memory(embd=True, db=True)

    if embedding_id == "None":
        pass

    # OpenAI
    elif embedding_id == "text-embedding-ada-002":
        ss.embeddings = OpenAIEmbeddings()

    # Hugging Face
    else:
        ss.embeddings = HuggingFaceEmbeddings(model_name=embedding_id)

    # --------------------------------------
    # Store the current model names on the SessionState object
    # --------------------------------------
    ss.current_model = model_id
    ss.current_embedding = embedding_id

    # Status message
    status_message = "✅ LLM: " + ss.current_model + ", embeddings: " + ss.current_embedding

    return ss, status_message

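# Usage sketch (illustration only; this mirrors what the Configure button
# below wires up, with hypothetical argument values):
#   ss_obj = SessionState()
#   ss_obj, msg = load_models(
#       ss_obj, "gpt-3.5-turbo", "None", openai_api_key="sk-...",
#       load_in_8bit=True, verbose=True, temperature=0.2, min_length=10,
#       max_new_tokens=256, top_k=40, top_p=0.92, repetition_penalty=1.2,
#       num_return_sequences=3)
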
def conversation_prep(ss: SessionState) -> SessionState:
    if ss.conversation_chain is None:

        human_prefix = "Human: "
        ai_prefix = "AI: "
        chat_template = chat_template_std

        if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft":
            # Settings for the Rinna model (newline fix, memory prefixes; see the official model page)
            chat_template = chat_template.replace("\n", "<NL>")
            human_prefix = "ユーザー: "
            ai_prefix = "システム: "

        elif ss.current_model.startswith("elyza/ELYZA-japanese-Llama-2-7b"):
            chat_template = chat_template_llama2

        chat_prompt = PromptTemplate(input_variables=['query', 'chat_history'], template=chat_template)

        if ss.memory is None:
            conversation_summary_prompt = PromptTemplate(input_variables=['summary', 'new_lines'], template=conversation_summary_template)
            ss.memory = ConversationSummaryBufferMemory(
                llm = ss.llm,
                memory_key = "chat_history",
                input_key = "query",
                output_key = "output_text",
                return_messages = True,
                human_prefix = human_prefix,
                ai_prefix = ai_prefix,
                max_token_limit = 512,
                prompt = conversation_summary_prompt,
            )

        ss.conversation_chain = ConversationChain(
            llm = ss.llm,
            prompt = chat_prompt,
            memory = ss.memory
        )

    return ss

def initialize_db(ss: SessionState) -> SessionState:

    # client = chromadb.PersistentClient(path="./db")
    ss.db = Chroma(
        collection_name = "user_reference",
        embedding_function = ss.embeddings,
        # client = client
    )

    return ss

def embedding_process(ss: SessionState, ref_documents: list) -> SessionState:

    # --------------------------------------
    # Restructure the text and remove unwanted strings
    # --------------------------------------
    for i in range(len(ref_documents)):
        content = ref_documents[i].page_content.strip()

        # --------------------------------------
        # For PDFs, apply heavier text repair to counter extraction errors
        # --------------------------------------
        if ".pdf" in ref_documents[i].metadata['source']:
            pdf_replacement_sets = [
                ('\n ', '**PLACEHOLDER+SPACE**'),
                ('\n\u3000', '**PLACEHOLDER+SPACE**'),
                ('.\n', '。**PLACEHOLDER**'),
                (',\n', '。**PLACEHOLDER**'),
                ('?\n', '。**PLACEHOLDER**'),
                ('!\n', '。**PLACEHOLDER**'),
                ('!\n', '。**PLACEHOLDER**'),
                ('。\n', '。**PLACEHOLDER**'),
                ('!\n', '!**PLACEHOLDER**'),
                (')\n', '!**PLACEHOLDER**'),
                (']\n', '!**PLACEHOLDER**'),
                ('?\n', '?**PLACEHOLDER**'),
                (')\n', '?**PLACEHOLDER**'),
                ('】\n', '?**PLACEHOLDER**'),
            ]
            for original, replacement in pdf_replacement_sets:
                content = content.replace(original, replacement)
            content = content.replace(" ", "")
        # --------------------------------------

        # Remove unwanted strings and whitespace
        remove_texts = ["\n", "\r", " "]
        for remove_text in remove_texts:
            content = content.replace(remove_text, "")

        # Convert tabs and full-width spaces to single spaces
        replace_texts = ["\t", "\u3000"]
        for replace_text in replace_texts:
            content = content.replace(replace_text, " ")

        # Restore the legitimate PDF line breaks
        if ".pdf" in ref_documents[i].metadata['source']:
            content = content.replace('**PLACEHOLDER**', '\n').replace('**PLACEHOLDER+SPACE**', '\n ')

        ref_documents[i].page_content = content

    # --------------------------------------
    # Split into chunks
    texts = text_splitter.split_documents(ref_documents)

    # --------------------------------------
    # Add the prefix the multilingual-e5 model was trained with
    # https://hironsan.hatenablog.com/entry/2023/07/05/073150
    # --------------------------------------
    if ss.current_embedding == "intfloat/multilingual-e5-large":
        for i in range(len(texts)):
            texts[i].page_content = "passage:" + texts[i].page_content

    # Initialize the vectordb
    if ss.db is None:
        ss = initialize_db(ss)

    # Embed into the db
    # ss.db = Chroma.from_documents(texts, ss.embeddings)
    ss.db.add_documents(documents=texts, embedding=ss.embeddings)

    # --------------------------------------
    # QA chain setup
    # --------------------------------------
    if ss.qa_chain is None:

        # QA memory
        human_prefix = "Human: "
        ai_prefix = "AI: "
        qa_template = qa_template_std

        if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft":
            # Settings for the Rinna model (newline fix, memory prefixes; see the official model page)
            qa_template = qa_template.replace("\n", "<NL>")
            human_prefix = "ユーザー: "
            ai_prefix = "システム: "

        elif ss.current_model.startswith("elyza/ELYZA-japanese-Llama-2-7b"):
            qa_template = qa_template_llama2

        qa_prompt = PromptTemplate(input_variables=['context', 'query', 'chat_history'], template=qa_template)

        if ss.memory is None:
            conversation_summary_prompt = PromptTemplate(input_variables=['summary', 'new_lines'], template=conversation_summary_template)
            ss.memory = ConversationSummaryBufferMemory(
                llm = ss.llm,
                memory_key = "chat_history",
                input_key = "query",
                output_key = "output_text",
                return_messages = True,
                human_prefix = human_prefix,
                ai_prefix = ai_prefix,
                max_token_limit = 512,
                prompt = conversation_summary_prompt,
            )

        ss.qa_chain = load_qa_chain(ss.llm, chain_type="stuff", memory=ss.memory, prompt=qa_prompt)

    return ss

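# Note on the placeholder round-trip above: sentence-final breaks are first
# swapped for **PLACEHOLDER** tokens so that the blanket "\n" removal only
# deletes the mid-sentence breaks PDF extraction introduces; the placeholders
# are then turned back into real newlines afterwards. For example:
#   "重要です。\nしかし\nそれは" → "重要です。\nしかしそれは"
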
def embed_ref(ss: SessionState, urls: str, fileobj: list, header_lim: int, footer_lim: int) -> (SessionState, str):

    url_flag = "-"
    pdf_flag = "-"

    # --------------------------------------
    # Load URLs and register them in the vectordb
    # --------------------------------------

    # Preprocess the URL list (listify, deduplicate, drop non-URLs)
    urls = list({url for url in urls.split("\n") if url and "://" in url})

    if urls:
        # Drop URLs already registered (ss.embedded_urls), then record the new ones
        urls = [url for url in urls if url not in ss.embedded_urls]
        ss.embedded_urls.extend(urls)

        # Load the web pages
        loader = SeleniumURLLoader(urls=urls)
        ref_documents = loader.load()

        # Run the embedding process
        ss = embedding_process(ss, ref_documents)

        url_flag = "✅ 登録済"

    # --------------------------------------
    # Strip PDF headers and footers, then register in the vectordb
    # https://pypdf.readthedocs.io/en/stable/user/extract-text.html
    # --------------------------------------

    if fileobj is None:
        pass

    else:
        # Collect the file names
        pdf_paths = []
        for path in fileobj:
            pdf_paths.append(path.name)

        # Initialize the list
        ref_documents = []

        # Read each PDF file
        for pdf_path in pdf_paths:
            pdf = PdfReader(pdf_path)
            body = []

            def visitor_body(text, cm, tm, font_dict, font_size):
                y = tm[5]
                if y > footer_lim and y < header_lim:  # keep text whose y coordinate lies between header and footer
                    parts.append(text)

            for page in pdf.pages:
                parts = []
                page.extract_text(visitor_text=visitor_body)
                body.append("".join(parts))

            body = "\n".join(body)

            # Take just the file name from the path
            filename = os.path.basename(pdf_path)
            # Convert the extracted text into a LangChain Document
            ref_documents.append(Document(page_content=body, metadata={"source": filename}))

        # Run the embedding process
        ss = embedding_process(ss, ref_documents)

        pdf_flag = "✅ 登録済"

    langchain.debug = True  # test: dump chain inputs/outputs to the console

    status_message = "URL: " + url_flag + " / PDF: " + pdf_flag
    return ss, status_message

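# How the header/footer filter works (per the pypdf docs): the visitor_text
# callback receives (text, current_matrix, text_matrix, font_dict, font_size)
# for every text fragment, and tm[5] is the fragment's y coordinate in points
# with the origin at the page's bottom-left. footer_lim < y < header_lim
# therefore drops running headers and footers; with the UI defaults below
# (header 792 pt, footer 0 pt) a full A4/letter page passes through unfiltered.
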
def clear_db(ss: SessionState) -> (SessionState, str):
    try:
        ss.db.delete_collection()
        status_message = "✅ 参照データを削除しました。"

    except AttributeError:  # ss.db is still None when nothing has been registered
        status_message = "❌ 参照データが登録されていません。"

    return ss, status_message

# ----------------------------------------------------------------------------
# query input ▶ [def user] ▶ [ def bot ] ▶ [def show_response] ▶ chat screen
#                   ⬇             ⬇                ⬆
#              chat screen   [qa_predict / chat_predict]
# ----------------------------------------------------------------------------

def user(ss: SessionState, query) -> (SessionState, list):
    # Drop the oldest entry once the history exceeds a fixed length
    if len(ss.dialogue) > 10:
        ss.dialogue.pop(0)

    ss.dialogue = ss.dialogue + [(query, None)]  # chat history (None = the bot's still-empty answer slot)
    chat_history = ss.dialogue

    # Chat screen = chat_history
    return ss, chat_history

def bot(ss: SessionState, query, qa_flag) -> (SessionState, str):
    if qa_flag is True:
        ss = qa_predict(ss, query)  # generate the answer with the LLM

    else:
        ss = conversation_prep(ss)
        ss = chat_predict(ss, query)

    return ss, ""  # ss and the (cleared) query box

def chat_predict(ss: SessionState, query) -> SessionState:
    response = ss.conversation_chain.predict(input=query)
    ss.dialogue[-1] = (ss.dialogue[-1][0], response)
    return ss

def qa_predict(ss: SessionState, query) -> SessionState:

    # Settings for the Rinna model (fix the query's newline codes)
    if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft":
        query = query.strip().replace("\n", "<NL>")
    else:
        query = query.strip()

    # Query prefix for multilingual-e5
    if ss.current_embedding == "intfloat/multilingual-e5-large":
        db_query_str = "query: " + query
    else:
        db_query_str = query

    # Retrieve the related documents and their sources from the DB
    docs = ss.db.similarity_search(db_query_str, k=2)
    sources = "\n\n[Sources]\n" + '\n - '.join(list(set(doc.metadata['source'] for doc in docs if 'source' in doc.metadata)))

    # Settings for the Rinna model (fix the retrieved documents' newline codes)
    if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft":
        for i in range(len(docs)):
            docs[i].page_content = docs[i].page_content.strip().replace("\n", "<NL>")

    # Generate the answer (up to three attempts)
    for _ in range(3):
        result = ss.qa_chain({"input_documents": docs, "query": query})
        result["output_text"] = result["output_text"].replace("<NL>", "\n").strip("...").strip("回答:").strip()

        # If result["output_text"] is not empty, update memory and return
        if result["output_text"] != "":
            response = result["output_text"] + sources
            ss.memory.chat_memory.messages = ss.memory.chat_memory.messages[:-1]  # drop the last turn
            ss.memory.chat_memory.add_user_message(query)
            ss.memory.chat_memory.add_ai_message(response)
            ss.dialogue[-1] = (ss.dialogue[-1][0], response)
            return ss
        else:
            # Empty: drop the latest history entry and retry
            ss.memory.chat_memory.messages = ss.memory.chat_memory.messages[:-1]

    # Still empty after three attempts
    response = "3回試行しましたが、情報を生成できませんでした。"
    if sources != "":
        response += "参考文献の抽出には成功していますので、言語モデルを変えてお試しください。"

    # Add the user and AI messages
    ss.memory.chat_memory.add_user_message(query.replace("<NL>", "\n"))
    ss.memory.chat_memory.add_ai_message(response)
    ss.dialogue[-1] = (ss.dialogue[-1][0], response)  # chat history
    return ss

# Stream the answer to the chat screen one character at a time
def show_response(ss: SessionState) -> str:
    # chat_history = ss.load_chat_history()  # fetch the chat history from memory as a list
    # response = chat_history[-1][1]         # latest turn [-1], bot answer [1]
    # chat_history[-1][1] = ""               # blank the answer for incremental display

    chat_history = [list(item) for item in ss.dialogue]  # convert tuples to lists to get the history from memory
    response = chat_history[-1][1]  # take the latest turn [-1] and set aside the bot's answer [1]
    chat_history[-1][1] = ""        # blank the answer [1] for incremental display

    for character in response:
        chat_history[-1][1] += character
        time.sleep(0.05)
        yield chat_history

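# Because show_response is a generator, Gradio re-renders the Chatbot component
# on every yield, which produces the typewriter effect; at 0.05 s per character
# a 200-character answer takes roughly 10 seconds to finish streaming.
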
with gr.Blocks() as demo:

    # Instantiate the per-user session memory (reset on page reload)
    ss = gr.State(SessionState())

    # --------------------------------------
    # Functions to set / clear the API key
    # --------------------------------------
    def openai_api_setfn(openai_api_key) -> str:
        if not openai_api_key or not openai_api_key.startswith("sk-") or len(openai_api_key) < 50:
            os.environ["OPENAI_API_KEY"] = ""
            status_message = "❌ 有効なAPIキーを入力してください"
            return status_message
        else:
            os.environ["OPENAI_API_KEY"] = openai_api_key
            status_message = "✅ APIキーを設定しました"
            return status_message

    def openai_api_clsfn(openai_api_key) -> (str, str):
        os.environ["OPENAI_API_KEY"] = ""
        status_message = "✅ APIキーの削除が完了しました"
        return status_message, ""

    # --------------------------------------
    # "Continue the answer" button
    # --------------------------------------
    def continue_pred():
        query = "回答を続けてください"
        return query

    with gr.Tabs():
        # --------------------------------------
        # Setting Tab
        # --------------------------------------
        with gr.TabItem("1. LLM設定"):
            with gr.Row():
                model_id = gr.Dropdown(
                    choices=[
                        'elyza/ELYZA-japanese-Llama-2-7b-fast-instruct',
                        'rinna/bilingual-gpt-neox-4b-instruction-sft',
                        'gpt-3.5-turbo',
                    ],
                    value="elyza/ELYZA-japanese-Llama-2-7b-fast-instruct",
                    label='LLM model',
                    interactive=True,
                )
            with gr.Row():
                embedding_id = gr.Dropdown(
                    choices=[
                        'intfloat/multilingual-e5-large',
                        'sonoisa/sentence-bert-base-ja-mean-tokens-v2',
                        'oshizo/sbert-jsnli-luke-japanese-base-lite',
                        'text-embedding-ada-002',
                        "None"
                    ],
                    value="sonoisa/sentence-bert-base-ja-mean-tokens-v2",
                    label = 'Embedding model',
                    interactive=True,
                )
            with gr.Row():
                with gr.Column(scale=19):
                    openai_api_key = gr.Textbox(label="OpenAI API Key (Optional)", interactive=True, type="password", value="", placeholder="Your OpenAI API Key for OpenAI models.", max_lines=1)
                with gr.Column(scale=1):
                    openai_api_set = gr.Button(value="Set API KEY", size="sm")
                    openai_api_cls = gr.Button(value="Delete API KEY", size="sm")

            # Advanced settings (collapsed)
            with gr.Accordion(label="Advanced Setting", open=False):
                with gr.Row():
                    with gr.Column():
                        load_in_8bit = gr.Checkbox(label="8bit Quantize (HF)", value=True, interactive=True)
                        verbose = gr.Checkbox(label="Verbose (OpenAI, HF)", value=True, interactive=False)
                    with gr.Column():
                        temperature = gr.Slider(label='Temperature (OpenAI, HF)', minimum=0.0, maximum=1.0, step=0.1, value=0.2, interactive=True)
                    with gr.Column():
                        min_length = gr.Slider(label="min_length (HF)", minimum=1, maximum=100, step=1, value=10, interactive=True)
                    with gr.Column():
                        max_new_tokens = gr.Slider(label="max_tokens(OpenAI), max_new_tokens(HF)", minimum=1, maximum=1024, step=1, value=256, interactive=True)
                    with gr.Column():
                        top_k = gr.Slider(label='top_k (HF)', minimum=1, maximum=100, step=1, value=40, interactive=True)
                    with gr.Column():
                        top_p = gr.Slider(label='top_p (HF)', minimum=0.01, maximum=0.99, step=0.01, value=0.92, interactive=True)
                    with gr.Column():
                        repetition_penalty = gr.Slider(label='repetition_penalty (HF)', minimum=0.5, maximum=2, step=0.1, value=1.2, interactive=True)
                    with gr.Column():
                        num_return_sequences = gr.Slider(label='num_return_sequences (HF)', minimum=1, maximum=20, step=1, value=3, interactive=True)

            with gr.Row():
                with gr.Column(scale=2):
                    config_btn = gr.Button(value="Configure")
                with gr.Column(scale=13):
                    status_cfg = gr.Textbox(show_label=False, interactive=False, value="モデルを設定してください", container=False, max_lines=1)

            # Wire up the buttons and other actions
            openai_api_set.click(openai_api_setfn, inputs=[openai_api_key], outputs=[status_cfg], show_progress="full")
            openai_api_cls.click(openai_api_clsfn, inputs=[openai_api_key], outputs=[status_cfg, openai_api_key], show_progress="full")
            openai_api_key.submit(openai_api_setfn, inputs=[openai_api_key], outputs=[status_cfg], show_progress="full")
            config_btn.click(
                fn = load_models,
                inputs = [ss, model_id, embedding_id, openai_api_key, load_in_8bit, verbose, temperature,
                          min_length, max_new_tokens, top_k, top_p, repetition_penalty, num_return_sequences],
                outputs = [ss, status_cfg],
                queue = True,
                show_progress = "full"
            )

        # --------------------------------------
        # Reference Tab
        # --------------------------------------
        with gr.TabItem("2. References"):
            urls = gr.TextArea(
                max_lines = 60,
                show_label=False,
                info = "List any reference URLs for Q&A retrieval.",
                placeholder = "https://blog.kikagaku.co.jp/deep-learning-transformer\nhttps://note.com/elyza/n/na405acaca130",
                interactive=True,
            )

            with gr.Row():
                pdf_paths = gr.File(label="PDFs", height=150, min_width=60, scale=7, file_types=[".pdf"], file_count="multiple", interactive=True)
                header_lim = gr.Number(label="Header (pt)", step=1, value=792, precision=0, min_width=70, scale=1, interactive=True)
                footer_lim = gr.Number(label="Footer (pt)", step=1, value=0, precision=0, min_width=70, scale=1, interactive=True)
                pdf_ref = gr.Textbox(show_label=False, value="A4 Size:\n(下)0-792pt(上)\n *28.35pt/cm", container=False, scale=1, interactive=False)

            with gr.Row():
                ref_set_btn = gr.Button(value="コンテンツ登録", scale=1)
                ref_clear_btn = gr.Button(value="登録データ削除", scale=1)
                status_ref = gr.Textbox(show_label=False, interactive=False, value="参照データ未登録", container=False, max_lines=1, scale=18)

            ref_set_btn.click(fn=embed_ref, inputs=[ss, urls, pdf_paths, header_lim, footer_lim], outputs=[ss, status_ref], queue=True, show_progress="full")
            ref_clear_btn.click(fn=clear_db, inputs=[ss], outputs=[ss, status_ref], show_progress="full")

        # --------------------------------------
        # Chatbot Tab
        # --------------------------------------
        with gr.TabItem("3. Q&A Chat"):
            chat_history = gr.Chatbot([], elem_id="chatbot").style(height=600, color_map=('green', 'gray'))
            with gr.Row():
                with gr.Column(scale=95):
                    query = gr.Textbox(
                        show_label=False,
                        placeholder="Send a message with [Shift]+[Enter] key.",
                        lines=4,
                        container=False,
                        autofocus=True,
                        interactive=True,
                    )
                with gr.Column(scale=5):
                    qa_flag = gr.Checkbox(label="QA mode", value=True, min_width=60, interactive=True)
                    query_send_btn = gr.Button(value="▶")

            # gr.Examples(["機械学習について説明してください"], inputs=[query])
            query.submit(user, [ss, query], [ss, chat_history]).then(bot, [ss, query, qa_flag], [ss, query]).then(show_response, [ss], [chat_history])
            query_send_btn.click(user, [ss, query], [ss, chat_history]).then(bot, [ss, query, qa_flag], [ss, query]).then(show_response, [ss], [chat_history])

if __name__ == "__main__":
    demo.queue(concurrency_count=5)
    demo.launch(debug=True, inbrowser=True)
requirements.txt
ADDED
@@ -0,0 +1,18 @@
accelerate
bitsandbytes
transformers
sentence_transformers
sentencepiece
langchain
xformers
chromadb
gradio
openai
tiktoken
fugashi
ipadic
unstructured
selenium
pypdf
requests