Beuys committed on
Commit
e929305
·
1 Parent(s): f7c8212

add chains

Browse files
chains/__pycache__/local_doc_qa.cpython-39.pyc ADDED
Binary file (11.5 kB). View file
 
chains/dialogue_answering/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .base import (
2
+ DialogueWithSharedMemoryChains
3
+ )
4
+
5
+ __all__ = [
6
+ "DialogueWithSharedMemoryChains"
7
+ ]
chains/dialogue_answering/__main__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import argparse
4
+ import asyncio
5
+ from argparse import Namespace
6
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../../')
7
+ from chains.dialogue_answering import *
8
+ from langchain.llms import OpenAI
9
+ from models.base import (BaseAnswer,
10
+ AnswerResult)
11
+ import models.shared as shared
12
+ from models.loader.args import parser
13
+ from models.loader import LoaderCheckPoint
14
+
15
async def dispatch(args: Namespace):
    """Build the dialogue agent from parsed CLI arguments and run one demo query.

    Args:
        args: Parsed command-line namespace. ``args.dialogue_path`` must point
            to an existing file; the whole namespace is also forwarded (as a
            dict) to ``LoaderCheckPoint`` and ``DialogueWithSharedMemoryChains``.

    Raises:
        FileNotFoundError: If ``args.dialogue_path`` is not an existing file.
    """
    # Fail fast: validate the cheap precondition before loading any model, so a
    # bad path is reported immediately instead of after the checkpoint load.
    if not os.path.isfile(args.dialogue_path):
        raise FileNotFoundError(f'Invalid dialogue file path for demo mode: "{args.dialogue_path}"')

    args_dict = vars(args)
    shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
    llm_model_ins = shared.loaderLLM()

    # The OpenAI model drives the zero-shot ReAct agent; the locally loaded
    # model answers the tool queries.
    llm = OpenAI(temperature=0)
    dialogue_instance = DialogueWithSharedMemoryChains(zero_shot_react_llm=llm, ask_llm=llm_model_ins, params=args_dict)

    dialogue_instance.agent_chain.run(input="What did David say before, summarize it")
26
+
27
+
28
if __name__ == '__main__':
    # Register the demo-specific options on the shared project parser.
    parser.add_argument('--dialogue-path', default='', type=str, help='dialogue-path')
    parser.add_argument('--embedding-model', default='', type=str, help='embedding-model')
    # Demo invocation with fixed arguments. The option name is spelled out in
    # full: the previous '--embedding-mode' only resolved through argparse's
    # prefix matching and would break if abbreviations were disabled or became
    # ambiguous with another option.
    args = parser.parse_args(['--dialogue-path', '/home/dmeck/Downloads/log.txt',
                              '--embedding-model', '/media/checkpoint/text2vec-large-chinese/'])
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(dispatch(args))
    finally:
        # Release the loop's resources even when dispatch() raises.
        loop.close()
chains/dialogue_answering/base.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.base_language import BaseLanguageModel
2
+ from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
3
+ from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory
4
+ from langchain.chains import LLMChain, RetrievalQA
5
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
6
+ from langchain.prompts import PromptTemplate
7
+ from langchain.text_splitter import CharacterTextSplitter
8
+ from langchain.vectorstores import Chroma
9
+
10
+ from loader import DialogueLoader
11
+ from chains.dialogue_answering.prompts import (
12
+ DIALOGUE_PREFIX,
13
+ DIALOGUE_SUFFIX,
14
+ SUMMARY_PROMPT
15
+ )
16
+
17
+
18
class DialogueWithSharedMemoryChains:
    """Agent over a recorded dialogue with a retrieval tool and a summary tool.

    Construction loads the dialogue file, indexes it in a Chroma collection,
    and wires a zero-shot ReAct agent whose conversation memory is shared
    (read-only) with the summary tool.
    """

    zero_shot_react_llm: BaseLanguageModel = None
    ask_llm: BaseLanguageModel = None
    embeddings: HuggingFaceEmbeddings = None
    embedding_model: str = None
    vector_search_top_k: int = 6
    dialogue_path: str = None
    dialogue_loader: DialogueLoader = None
    device: str = None

    def __init__(self, zero_shot_react_llm: BaseLanguageModel = None, ask_llm: BaseLanguageModel = None,
                 params: dict = None):
        """Configure the chain and eagerly build every component.

        Args:
            zero_shot_react_llm: Model that drives the ReAct agent.
            ask_llm: Model used by the retrieval-QA and summary tools.
            params: Optional settings; the keys read are ``embedding_model``,
                ``vector_search_top_k``, ``dialogue_path`` and ``use_cuda``.
        """
        self.zero_shot_react_llm = zero_shot_react_llm
        self.ask_llm = ask_llm
        params = params or {}
        # Use `or` rather than a .get() default: a caller that passes an empty
        # string (e.g. an argparse option left at default='') must still fall
        # back to the default model instead of loading a model named ''.
        self.embedding_model = params.get('embedding_model') or 'GanymedeNil/text2vec-large-chinese'
        self.vector_search_top_k = params.get('vector_search_top_k', 6)
        self.dialogue_path = params.get('dialogue_path', '')
        self.device = 'cuda' if params.get('use_cuda', False) else 'cpu'

        # Order matters: the history index needs the loader and embeddings,
        # and the agent needs both the history chain and the memory chain.
        self.dialogue_loader = DialogueLoader(self.dialogue_path)
        self._init_cfg()
        self._init_state_of_history()
        self.memory_chain, self.memory = self._agents_answer()
        self.agent_chain = self._create_agent_chain()

    def _init_cfg(self):
        """Instantiate the HuggingFace embeddings on the configured device."""
        model_kwargs = {
            'device': self.device
        }
        self.embeddings = HuggingFaceEmbeddings(model_name=self.embedding_model, model_kwargs=model_kwargs)

    def _init_state_of_history(self):
        """Index the dialogue in Chroma and build the retrieval-QA chain."""
        documents = self.dialogue_loader.load()
        # NOTE(review): chunk_size=3 / chunk_overlap=1 are unusually small
        # values; presumably meant to keep each dialogue entry as its own
        # chunk -- confirm against DialogueLoader's output.
        text_splitter = CharacterTextSplitter(chunk_size=3, chunk_overlap=1)
        texts = text_splitter.split_documents(documents)
        docsearch = Chroma.from_documents(texts, self.embeddings, collection_name="state-of-history")
        # NOTE(review): self.vector_search_top_k is stored in __init__ but is
        # not passed to the retriever here, so the retriever default applies.
        self.state_of_history = RetrievalQA.from_chain_type(llm=self.ask_llm, chain_type="stuff",
                                                            retriever=docsearch.as_retriever())

    def _agents_answer(self):
        """Create the summary chain and the conversation memory it reads.

        Returns:
            A ``(memory_chain, memory)`` pair: the summary ``LLMChain`` and the
            writable ``ConversationBufferMemory`` backing it.
        """
        memory = ConversationBufferMemory(memory_key="chat_history")
        readonly_memory = ReadOnlySharedMemory(memory=memory)
        memory_chain = LLMChain(
            llm=self.ask_llm,
            prompt=SUMMARY_PROMPT,
            verbose=True,
            memory=readonly_memory,  # use the read-only memory to prevent the tool from modifying the memory
        )
        return memory_chain, memory

    def _create_agent_chain(self):
        """Assemble the two tools and the ReAct agent executor.

        Returns:
            The ``AgentExecutor`` bound to the shared conversation memory.
        """
        dialogue_participants = self.dialogue_loader.dialogue.participants_to_export()
        tools = [
            Tool(
                name="State of Dialogue History System",
                func=self.state_of_history.run,
                description=f"Dialogue with {dialogue_participants} - The answers in this section are very useful "
                            f"when searching for chat content between {dialogue_participants}. Input should be a "
                            f"complete question. "
            ),
            Tool(
                name="Summary",
                func=self.memory_chain.run,
                description="useful for when you summarize a conversation. The input to this tool should be a string, "
                            "representing who will read this summary. "
            )
        ]

        prompt = ZeroShotAgent.create_prompt(
            tools,
            prefix=DIALOGUE_PREFIX,
            suffix=DIALOGUE_SUFFIX,
            input_variables=["input", "chat_history", "agent_scratchpad"]
        )

        llm_chain = LLMChain(llm=self.zero_shot_react_llm, prompt=prompt)
        agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)
        agent_chain = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True, memory=self.memory)

        return agent_chain
chains/dialogue_answering/prompts.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.prompts.prompt import PromptTemplate
2
+
3
+
4
# Template for the "Summary" tool: the shared chat history is rendered first,
# then the model is asked to summarise it for the reader named in {input}.
SUMMARY_TEMPLATE = """This is a conversation between a human and a bot:

{chat_history}

Write a summary of the conversation for {input}:
"""

SUMMARY_PROMPT = PromptTemplate(
    template=SUMMARY_TEMPLATE,
    input_variables=["input", "chat_history"],
)

# Text placed before the tool descriptions in the agent prompt.
DIALOGUE_PREFIX = """Have a conversation with a human,Analyze the content of the conversation.
You have access to the following tools: """

# Text placed after the tool descriptions; the three placeholders are filled
# in by the agent at run time.
DIALOGUE_SUFFIX = """Begin!

{chat_history}
Question: {input}
{agent_scratchpad}"""
chains/local_doc_qa.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
2
+ from vectorstores import MyFAISS
3
+ from langchain.document_loaders import UnstructuredFileLoader, TextLoader, CSVLoader
4
+ from configs.model_config import *
5
+ import datetime
6
+ from textsplitter import ChineseTextSplitter
7
+ from typing import List
8
+ from utils import torch_gc
9
+ from tqdm import tqdm
10
+ from pypinyin import lazy_pinyin
11
+ from models.base import (BaseAnswer,
12
+ AnswerResult)
13
+ from models.loader.args import parser
14
+ from models.loader import LoaderCheckPoint
15
+ import models.shared as shared
16
+ from agent import bing_search
17
+ from langchain.docstore.document import Document
18
+ from functools import lru_cache
19
+ from textsplitter.zh_title_enhance import zh_title_enhance
20
+ from langchain.chains.base import Chain
21
+
22
+
23
+ # patch HuggingFaceEmbeddings to make it hashable
24
def _embeddings_hash(self):
    """Hash an embeddings object by its ``model_name`` attribute."""
    model_id = self.model_name
    return hash(model_id)
26
+
27
+
28
+ HuggingFaceEmbeddings.__hash__ = _embeddings_hash
29
+
30
+
31
+ # will keep CACHED_VS_NUM of vector store caches
32
@lru_cache(maxsize=CACHED_VS_NUM)
def load_vector_store(vs_path, embeddings):
    """Load the FAISS store saved at ``vs_path``, memoised per argument pair.

    Up to CACHED_VS_NUM loaded stores are kept; both arguments must be
    hashable because they form the cache key.
    """
    store = MyFAISS.load_local(vs_path, embeddings)
    return store
35
+
36
+
37
def tree(filepath, ignore_dir_names=None, ignore_file_names=None):
    """Return two lists: the full paths of all files under ``filepath`` and
    the corresponding file names.

    A path that does not exist yields ``(None, None)``. Directories and files
    whose base name appears in the respective ignore list are skipped.
    """
    skip_dirs = [] if ignore_dir_names is None else ignore_dir_names
    skip_files = [] if ignore_file_names is None else ignore_file_names
    collected = []

    # Non-string input is not walked at all.
    if not isinstance(filepath, str):
        return collected, []

    if not os.path.exists(filepath):
        print("路径不存在")
        return None, None

    base_name = os.path.basename(filepath)
    if os.path.isfile(filepath) and base_name not in skip_files:
        return [filepath], [base_name]

    if os.path.isdir(filepath) and base_name not in skip_dirs:
        for entry in os.listdir(filepath):
            child = os.path.join(filepath, entry)
            child_name = os.path.basename(child)
            if os.path.isfile(child) and child_name not in skip_files:
                collected.append(child)
            if os.path.isdir(child) and child_name not in skip_dirs:
                nested_paths = tree(child, skip_dirs, skip_files)[0]
                collected.extend(nested_paths)

    names = [os.path.basename(path) for path in collected]
    return collected, names
58
+
59
+
60
def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_TITLE_ENHANCE):
    """Load one file into a list of documents, choosing the loader by extension.

    The extension check is case-insensitive. After loading, the documents are
    optionally passed through ``zh_title_enhance`` and a record of them is
    appended to the check file via ``write_check_file``.
    """
    lowered = filepath.lower()

    if lowered.endswith(".md"):
        md_loader = UnstructuredFileLoader(filepath, mode="elements")
        docs = md_loader.load()
    elif lowered.endswith(".txt"):
        txt_loader = TextLoader(filepath, autodetect_encoding=True)
        splitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
        docs = txt_loader.load_and_split(splitter)
    elif lowered.endswith(".pdf"):
        # The paddle-based loaders are imported lazily for now, so that
        # protobuf=4.x can be used as long as no pdf/image knowledge files
        # are uploaded.
        from loader import UnstructuredPaddlePDFLoader
        pdf_loader = UnstructuredPaddlePDFLoader(filepath)
        splitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size)
        docs = pdf_loader.load_and_split(splitter)
    elif lowered.endswith((".jpg", ".png")):
        # Lazy import for the same reason as the PDF loader above.
        from loader import UnstructuredPaddleImageLoader
        image_loader = UnstructuredPaddleImageLoader(filepath, mode="elements")
        splitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
        docs = image_loader.load_and_split(text_splitter=splitter)
    elif lowered.endswith(".csv"):
        csv_loader = CSVLoader(filepath)
        docs = csv_loader.load()
    else:
        # Any other extension goes through the generic unstructured loader.
        generic_loader = UnstructuredFileLoader(filepath, mode="elements")
        splitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
        docs = generic_loader.load_and_split(text_splitter=splitter)

    if using_zh_title_enhance:
        docs = zh_title_enhance(docs)
    write_check_file(filepath, docs)
    return docs
92
+
93
+
94
def write_check_file(filepath, docs):
    """Append a record of the loaded documents to a check file.

    The record is written to ``tmp_files/load_file.txt`` next to ``filepath``:
    a header line ``filepath=<path>,len=<count>`` followed by ``str(doc)`` on
    one line per document.

    Args:
        filepath: Path of the source file the documents were loaded from.
        docs: Sized iterable of loaded documents.
    """
    folder_path = os.path.join(os.path.dirname(filepath), "tmp_files")
    # exist_ok avoids the check-then-create race when several files from the
    # same directory are loaded concurrently.
    os.makedirs(folder_path, exist_ok=True)
    fp = os.path.join(folder_path, 'load_file.txt')
    # The context manager closes the file; no explicit close() is needed.
    with open(fp, 'a+', encoding='utf-8') as fout:
        fout.write("filepath=%s,len=%s" % (filepath, len(docs)))
        fout.write('\n')
        for doc in docs:
            fout.write(str(doc))
            fout.write('\n')
106
+
107
+
108
def generate_prompt(related_docs: "List[Document]",
                    query: str,
                    prompt_template: str = PROMPT_TEMPLATE, ) -> str:
    """Fill ``{question}`` and ``{context}`` in the prompt template.

    Args:
        related_docs: Documents whose ``page_content`` values are joined with
            newlines to form the context.
        query: The user question.
        prompt_template: Template containing the literal placeholders
            ``{question}`` and ``{context}``.

    Returns:
        The template with both placeholders substituted.
    """
    import re

    context = "\n".join([doc.page_content for doc in related_docs])
    substitutions = {"{question}": query, "{context}": context}
    # Substitute in a single pass so that text inserted for one placeholder is
    # never re-scanned: with chained str.replace calls, a query containing the
    # literal "{context}" would itself be expanded with the retrieved context.
    return re.sub(r"\{question\}|\{context\}",
                  lambda match: substitutions[match.group(0)],
                  prompt_template)
114
+
115
+
116
def search_result2docs(search_results):
    """Convert search-result mappings into ``Document`` objects.

    ``snippet`` becomes the page content, ``link`` the ``source`` metadata and
    ``title`` the ``filename`` metadata; a missing key maps to an empty string.
    """
    return [
        Document(
            page_content=hit.get("snippet", ""),
            metadata={
                "source": hit.get("link", ""),
                "filename": hit.get("title", ""),
            },
        )
        for hit in search_results
    ]
124
+
125
+
126
class LocalDocQA:
    """Question answering over a local FAISS knowledge base or web search results."""

    llm_model_chain: Chain = None
    embeddings: object = None
    top_k: int = VECTOR_SEARCH_TOP_K
    chunk_size: int = CHUNK_SIZE
    chunk_conent: bool = True
    score_threshold: int = VECTOR_SEARCH_SCORE_THRESHOLD

    def init_cfg(self,
                 embedding_model: str = EMBEDDING_MODEL,
                 embedding_device=EMBEDDING_DEVICE,
                 llm_model: Chain = None,
                 top_k=VECTOR_SEARCH_TOP_K,
                 ):
        """Attach the LLM chain and create the embeddings.

        Args:
            embedding_model: Key into ``embedding_model_dict``.
            embedding_device: Device passed to the embeddings' ``model_kwargs``.
            llm_model: Chain used to generate answers.
            top_k: Number of results requested from the vector store.
        """
        self.llm_model_chain = llm_model
        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
                                                model_kwargs={'device': embedding_device})
        self.top_k = top_k

    def init_knowledge_vector_store(self,
                                    filepath: str or List[str],
                                    vs_path: str or os.PathLike = None,
                                    sentence_size=SENTENCE_SIZE):
        """Load files into a new or existing vector store and save it.

        Args:
            filepath: A file path, a directory path, or a list of file paths.
            vs_path: Vector-store directory. When it already holds an
                ``index.faiss`` the documents are added to it; when falsy a new
                directory under ``KB_ROOT_PATH`` is generated.
            sentence_size: Sentence size forwarded to ``load_file``.

        Returns:
            ``(vs_path, loaded_files)`` on success, ``(None, loaded_files)``
            when nothing could be loaded. Every path returns a 2-tuple so that
            callers can always unpack the result.
        """
        loaded_files = []
        failed_files = []
        # Initialised up front so that a path which exists but is neither a
        # regular file nor a directory does not leave `docs` unbound.
        docs = []
        # Base name used to derive the auto-generated store directory name.
        vs_name_seed = None
        if isinstance(filepath, str):
            if not os.path.exists(filepath):
                print("路径不存在")
                return None, loaded_files
            elif os.path.isfile(filepath):
                file = os.path.split(filepath)[-1]
                vs_name_seed = file
                try:
                    docs = load_file(filepath, sentence_size)
                    logger.info(f"{file} 已成功加载")
                    loaded_files.append(filepath)
                except Exception as e:
                    logger.error(e)
                    logger.info(f"{file} 未能成功加载")
                    return None, loaded_files
            elif os.path.isdir(filepath):
                for fullfilepath, file in tqdm(zip(*tree(filepath, ignore_dir_names=['tmp_files'])), desc="加载文件"):
                    vs_name_seed = file
                    try:
                        docs += load_file(fullfilepath, sentence_size)
                        loaded_files.append(fullfilepath)
                    except Exception as e:
                        logger.error(e)
                        failed_files.append(file)

                if len(failed_files) > 0:
                    logger.info("以下文件未能成功加载:")
                    for failed in failed_files:
                        logger.info(f"{failed}\n")

        else:
            for file in filepath:
                # Keep only the base name: an absolute path here would make
                # os.path.join() below discard KB_ROOT_PATH entirely.
                vs_name_seed = os.path.basename(file)
                try:
                    # Honour the caller's sentence_size, as the other branches do.
                    docs += load_file(file, sentence_size)
                    logger.info(f"{file} 已成功加载")
                    loaded_files.append(file)
                except Exception as e:
                    logger.error(e)
                    logger.info(f"{file} 未能成功加载")
        if len(docs) > 0:
            logger.info("文件加载完毕,正在生成向量库")
            if vs_path and os.path.isdir(vs_path) and "index.faiss" in os.listdir(vs_path):
                vector_store = load_vector_store(vs_path, self.embeddings)
                vector_store.add_documents(docs)
                torch_gc()
            else:
                if not vs_path:
                    vs_path = os.path.join(KB_ROOT_PATH,
                                           f"""{"".join(lazy_pinyin(os.path.splitext(vs_name_seed)[0]))}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""",
                                           "vector_store")
                vector_store = MyFAISS.from_documents(docs, self.embeddings)  # docs is a list of Document
                torch_gc()

            vector_store.save_local(vs_path)
            return vs_path, loaded_files
        else:
            logger.info("文件均未成功加载,请检查依赖包或替换为其他文件再次上传。")

            return None, loaded_files

    def one_knowledge_add(self, vs_path, one_title, one_conent, one_content_segmentation, sentence_size):
        """Add a single titled text entry to the vector store at ``vs_path``.

        Args:
            vs_path: Vector-store directory (created on save if it holds no index).
            one_title: Title stored as the document's ``source`` metadata.
            one_conent: Text content of the entry.
            one_content_segmentation: When falsy, the content is split with
                ``ChineseTextSplitter`` before indexing.
            sentence_size: Sentence size for the splitter.

        Returns:
            ``(vs_path, [one_title])`` on success, ``(None, [one_title])`` on
            invalid input or any error (which is logged, not raised).
        """
        try:
            if not vs_path or not one_title or not one_conent:
                logger.info("知识库添加错误,请确认知识库名字、标题、内容是否正确!")
                return None, [one_title]
            docs = [Document(page_content=one_conent + "\n", metadata={"source": one_title})]
            if not one_content_segmentation:
                text_splitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
                docs = text_splitter.split_documents(docs)
            if os.path.isdir(vs_path) and os.path.isfile(os.path.join(vs_path, "index.faiss")):
                vector_store = load_vector_store(vs_path, self.embeddings)
                vector_store.add_documents(docs)
            else:
                vector_store = MyFAISS.from_documents(docs, self.embeddings)  # docs is a list of Document
            torch_gc()
            vector_store.save_local(vs_path)
            return vs_path, [one_title]
        except Exception as e:
            logger.error(e)
            return None, [one_title]

    def get_knowledge_based_answer(self, query, vs_path, chat_history=None, streaming: bool = STREAMING):
        """Answer ``query`` using documents retrieved from the store at ``vs_path``.

        Args:
            query: The user question.
            vs_path: Vector-store directory.
            chat_history: Prior history passed to the model; defaults to a
                fresh empty list for each call.
            streaming: Forwarded to the model chain.

        Yields:
            ``(response, history)`` pairs, where ``response`` has the keys
            ``query``, ``result`` and ``source_documents``.
        """
        # A mutable default would be shared between calls; create it per call.
        chat_history = [] if chat_history is None else chat_history
        vector_store = load_vector_store(vs_path, self.embeddings)
        vector_store.chunk_size = self.chunk_size
        vector_store.chunk_conent = self.chunk_conent
        vector_store.score_threshold = self.score_threshold
        related_docs_with_score = vector_store.similarity_search_with_score(query, k=self.top_k)
        torch_gc()
        if len(related_docs_with_score) > 0:
            prompt = generate_prompt(related_docs_with_score, query)
        else:
            # Nothing retrieved: fall back to the bare question.
            prompt = query

        # Code path for Baichuan models:
        if LLM_MODEL == "Baichuan-13B-Chat":
            for answer_result in self.llm_model_chain._generate_answer(prompt=prompt, history=chat_history,
                                                                       streaming=streaming):
                resp = answer_result.llm_output["answer"]
                history = answer_result.history
                response = {"query": query,
                            "result": resp,
                            "source_documents": related_docs_with_score}
                yield response, history
        else:  # Original code path:
            answer_result_stream_result = self.llm_model_chain(
                {"prompt": prompt, "history": chat_history, "streaming": streaming})

            for answer_result in answer_result_stream_result['answer_result_stream']:
                resp = answer_result.llm_output["answer"]
                history = answer_result.history
                # Record the original question rather than the expanded prompt.
                history[-1][0] = query
                response = {"query": query,
                            "result": resp,
                            "source_documents": related_docs_with_score}
                yield response, history

    # query: the search query
    # vs_path: knowledge-base (vector store) path
    # chunk_conent: whether context association is enabled
    # score_threshold: score threshold for search matches
    # vector_search_top_k: number of knowledge-base entries to retrieve
    # chunk_size: length of the context joined around a single matched passage
    def get_knowledge_based_conent_test(self, query, vs_path, chunk_conent,
                                        score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD,
                                        vector_search_top_k=VECTOR_SEARCH_TOP_K, chunk_size=CHUNK_SIZE):
        """Run retrieval only, with per-call search settings.

        Returns:
            ``(response, prompt)``: ``response`` has the keys ``query`` and
            ``source_documents``; ``prompt`` is the retrieved page contents
            joined by newlines, or ``""`` when nothing matched.
        """
        vector_store = load_vector_store(vs_path, self.embeddings)
        vector_store.chunk_conent = chunk_conent
        vector_store.score_threshold = score_threshold
        vector_store.chunk_size = chunk_size
        related_docs_with_score = vector_store.similarity_search_with_score(query, k=vector_search_top_k)
        if not related_docs_with_score:
            response = {"query": query,
                        "source_documents": []}
            return response, ""
        torch_gc()
        prompt = "\n".join([doc.page_content for doc in related_docs_with_score])
        response = {"query": query,
                    "source_documents": related_docs_with_score}
        return response, prompt

    def get_search_result_based_answer(self, query, chat_history=None, streaming: bool = STREAMING):
        """Answer ``query`` using ``bing_search`` results as the context.

        Args:
            query: The user question.
            chat_history: Prior history passed to the model; defaults to a
                fresh empty list for each call.
            streaming: Forwarded to the model chain.

        Yields:
            ``(response, history)`` pairs, where ``response`` has the keys
            ``query``, ``result`` and ``source_documents``.
        """
        # A mutable default would be shared between calls; create it per call.
        chat_history = [] if chat_history is None else chat_history
        results = bing_search(query)
        result_docs = search_result2docs(results)
        prompt = generate_prompt(result_docs, query)

        answer_result_stream_result = self.llm_model_chain(
            {"prompt": prompt, "history": chat_history, "streaming": streaming})

        for answer_result in answer_result_stream_result['answer_result_stream']:
            resp = answer_result.llm_output["answer"]
            history = answer_result.history
            # Record the original question rather than the expanded prompt.
            history[-1][0] = query
            response = {"query": query,
                        "result": resp,
                        "source_documents": result_docs}
            yield response, history

    def delete_file_from_vector_store(self,
                                      filepath: str or List[str],
                                      vs_path):
        """Delete ``filepath`` from the store at ``vs_path``; return the store's status."""
        vector_store = load_vector_store(vs_path, self.embeddings)
        status = vector_store.delete_doc(filepath)
        return status

    def update_file_from_vector_store(self,
                                      filepath: str or List[str],
                                      vs_path,
                                      docs: List[Document], ):
        """Update ``filepath`` in the store at ``vs_path`` with ``docs``; return the store's status."""
        vector_store = load_vector_store(vs_path, self.embeddings)
        status = vector_store.update_doc(filepath, docs)
        return status

    def list_file_from_vector_store(self,
                                    vs_path,
                                    fullpath=False):
        """List the documents in the store at ``vs_path``.

        Returns the entries from ``list_docs()`` unchanged when ``fullpath`` is
        true, otherwise only the last path component of each entry.
        """
        vector_store = load_vector_store(vs_path, self.embeddings)
        docs = vector_store.list_docs()
        if fullpath:
            return docs
        else:
            return [os.path.split(doc)[-1] for doc in docs]
333
+
334
+
335
if __name__ == "__main__":
    # Initialisation: load the checkpoint and the LLM from fixed demo arguments.
    args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'chatglm-6b', '--no-remote-model'])

    args_dict = vars(args)
    shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
    llm_model_ins = shared.loaderLLM()

    local_doc_qa = LocalDocQA()
    local_doc_qa.init_cfg(llm_model=llm_model_ins)
    query = "本项目使用的embedding模型是什么,消耗多少显存"
    vs_path = "/media/gpt4-pdf-chatbot-langchain/dev-langchain-ChatGLM/vector_store/test"
    last_print_len = 0
    # The knowledge-base variant of this demo would iterate over
    # local_doc_qa.get_knowledge_based_answer(query=query, vs_path=vs_path,
    # chat_history=[], streaming=True) instead.
    answer_stream = local_doc_qa.get_search_result_based_answer(query=query,
                                                                chat_history=[],
                                                                streaming=True)
    for resp, history in answer_stream:
        # Each response carries the full answer so far; print only the new tail.
        print(resp["result"][last_print_len:], end="", flush=True)
        last_print_len = len(resp["result"])

    source_text = []
    for inum, doc in enumerate(resp["source_documents"]):
        source = doc.metadata['source']
        # Web sources are shown as-is; local sources by file name only.
        label = source if source.startswith("http") else os.path.split(source)[-1]
        source_text.append(f"""出处 [{inum + 1}] {label}:\n\n{doc.page_content}\n\n""")
    logger.info("\n\n" + "\n\n".join(source_text))
chains/text_load.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pinecone
3
+ from tqdm import tqdm
4
+ from langchain.llms import OpenAI
5
+ from langchain.text_splitter import SpacyTextSplitter
6
+ from langchain.document_loaders import TextLoader
7
+ from langchain.document_loaders import DirectoryLoader
8
+ from langchain.indexes import VectorstoreIndexCreator
9
+ from langchain.embeddings.openai import OpenAIEmbeddings
10
+ from langchain.vectorstores import Pinecone
11
+
12
# Configuration values (placeholders; replace before running)
openai_key = "你的key"  # obtained after registering at openai.com
pinecone_key = "你的key"  # obtained after registering at app.pinecone.io
pinecone_index = "你的库"  # obtained from app.pinecone.io
pinecone_environment = "你的Environment"  # shown on the indexes page after logging in to pinecone
pinecone_namespace = "你的Namespace"  # created automatically if it does not exist

# Route HTTP(S) traffic through a local proxy
os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'
os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'

# Initialise pinecone
pinecone.init(
    api_key=pinecone_key,
    environment=pinecone_environment
)
index = pinecone.Index(pinecone_index)

# Initialise the OpenAI embeddings
embeddings = OpenAIEmbeddings(openai_api_key=openai_key)

# Initialise the text splitter
text_splitter = SpacyTextSplitter(pipeline='zh_core_web_sm', chunk_size=1000, chunk_overlap=200)

# Select every file under the directory whose suffix is .txt
loader = DirectoryLoader('../docs', glob="**/*.txt", loader_cls=TextLoader)

# Read the text files
documents = loader.load()

# Split the documents with the text splitter
split_text = text_splitter.split_documents(documents)
try:
    for document in tqdm(split_text):
        # Embed the chunk and store it in pinecone. The namespace configured
        # above is passed explicitly; it was previously defined but never
        # used, so every vector landed in the index's default namespace.
        Pinecone.from_documents([document], embeddings, index_name=pinecone_index,
                                namespace=pinecone_namespace)
except Exception as e:
    print(f"Error: {e}")
    # quit() is only injected by the site module and exits with status 0;
    # raise SystemExit with a non-zero status so failures are detectable.
    raise SystemExit(1)
51
+
52
+