Spaces:
Runtime error

Add application file
app.py
ADDED
@@ -0,0 +1,318 @@
#import gradio as gr
#import cv2
#def to_black(image):
#    output = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
#    return output
#interface = gr.Interface(fn=to_black, inputs="image", outputs="image")
#print('here')
#interface.launch()

#print(share_url)
#print(local_url)
#print(app)
#interface.launch(inbrowser=True, share=True, port=8888)
#url = interface.share()
#print(url)
from langchain.chains import RetrievalQA
from langchain.document_loaders import UnstructuredFileLoader, TextLoader, CSVLoader
from langchain.vectorstores import DocArrayInMemorySearch
from langchain.indexes import VectorstoreIndexCreator
from langchain.prompts import PromptTemplate
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain import HuggingFacePipeline
import torch
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.base import Chain
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.summarize import load_summarize_chain
import gradio as gr
from typing import List, Union
from tqdm import tqdm
import logging
import argparse
import os
import string

CHUNK_SIZE = 600
CHUNK_OVERLAP = 100
SEARCH_TOP_K = 5
logger = logging.getLogger("bio_LLM_logger")

def tree(filepath, ignore_dir_names=None, ignore_file_names=None):
    """Return two lists: the first holds the full paths of every file under
    filepath, the second the corresponding file names."""
    if ignore_dir_names is None:
        ignore_dir_names = []
    if ignore_file_names is None:
        ignore_file_names = []
    ret_list = []
    if isinstance(filepath, str):
        if not os.path.exists(filepath):
            print("path does not exist")
            return None, None
        elif os.path.isfile(filepath) and os.path.basename(filepath) not in ignore_file_names:
            return [filepath], [os.path.basename(filepath)]
        elif os.path.isdir(filepath) and os.path.basename(filepath) not in ignore_dir_names:
            for file in os.listdir(filepath):
                fullfilepath = os.path.join(filepath, file)
                if os.path.isfile(fullfilepath) and os.path.basename(fullfilepath) not in ignore_file_names:
                    ret_list.append(fullfilepath)
                if os.path.isdir(fullfilepath) and os.path.basename(fullfilepath) not in ignore_dir_names:
                    ret_list.extend(tree(fullfilepath, ignore_dir_names, ignore_file_names)[0])
    return ret_list, [os.path.basename(p) for p in ret_list]

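# Illustrative call, with a hypothetical directory layout (not from the repo):
#   paths, names = tree("data/test", ignore_dir_names=["tmp_files"])
#   # -> (["data/test/a.pdf", "data/test/sub/b.txt"], ["a.pdf", "b.txt"])
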
def load_file(file_path, chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP):
    if file_path.lower().endswith(".pdf"):
        loader = UnstructuredFileLoader(file_path, mode="elements")
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        docs = loader.load_and_split(text_splitter=text_splitter)
    elif file_path.lower().endswith(".txt"):
        loader = TextLoader(file_path, autodetect_encoding=True)
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        docs = loader.load_and_split(text_splitter=text_splitter)
    elif file_path.lower().endswith(".csv"):
        loader = CSVLoader(file_path)
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        docs = loader.load_and_split(text_splitter=text_splitter)
    else:
        # the original printed a message and then returned an unbound name
        raise ValueError(f"unsupported file format: {file_path}")
    return docs

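# Note: every branch splits documents into roughly CHUNK_SIZE-character pieces
# with CHUNK_OVERLAP characters shared between neighbouring chunks, so context
# survives the cut points. A minimal sketch of the same splitter on raw text
# (illustrative only):
#   splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=100)
#   chunks = splitter.split_text(some_long_string)
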
#class summary_chain:
#    def init_cfg(self,
#                 llm_model: Chain,

def summary(model, chain_type, PROMPT, REFINE_PROMPT, docs):
    if chain_type == "stuff":
        chain = load_summarize_chain(model, chain_type="stuff", prompt=PROMPT)
    elif chain_type == "refine":
        chain = load_summarize_chain(model, chain_type="refine", question_prompt=PROMPT, refine_prompt=REFINE_PROMPT)
    print(chain.run(docs))

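# Hypothetical usage, assuming PROMPT and REFINE_PROMPT are PromptTemplates
# over the input variable(s) the summarize chains expect:
#   summary(model, "refine", PROMPT, REFINE_PROMPT, docs)
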
class QA_Localdb:
    llm_model_chain: Chain = None
    embeddings: object = None
    top_k: int = SEARCH_TOP_K
    chunk_size: int = CHUNK_SIZE

    def init_cfg(self,
                 llm_model: Chain,
                 embedding_model: str,
                 #embedding_device: str,
                 top_k=SEARCH_TOP_K,
                 ):
        self.llm_model_chain = llm_model
        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
        self.top_k = top_k

    def init_knowledge_vector_store(self,
                                    file_path: Union[str, List[str]],
                                    vectorstore_path: Union[str, os.PathLike] = None,
                                    ):
        loaded_files = []
        failed_files = []
        if isinstance(file_path, str):
            if not os.path.exists(file_path):
                print("unknown path")
                return None
            elif os.path.isfile(file_path):
                file = os.path.split(file_path)[-1]
                try:
                    docs = load_file(file_path)
                    logger.info(f"{file} loaded successfully")
                    loaded_files.append(file_path)
                except Exception as e:
                    logger.error(e)
                    logger.info(f"{file} failed to load")
                    return None

            elif os.path.isdir(file_path):
                docs = []
                for fullfilepath, file in tqdm(zip(*tree(file_path, ignore_dir_names=['tmp_files'])), desc="load file"):
                    try:
                        docs += load_file(fullfilepath)
                        loaded_files.append(fullfilepath)
                    except Exception as e:
                        logger.error(e)
                        failed_files.append(file)

                if len(failed_files) > 0:
                    logger.info('unloaded files are as follows')
                    for file in failed_files:
                        logger.info(f"{file}\n")
        else:
            docs = []
            for file in file_path:
                try:
                    docs += load_file(file)
                    logger.info(f"{file} loaded successfully")
                    loaded_files.append(file)
                except Exception as e:
                    logger.error(e)
                    logger.info(f"{file} failed to load")
        if len(docs) > 0:
            logger.info("loaded successfully, generating vector store")
            if vectorstore_path and os.path.isdir(vectorstore_path) and "index.faiss" in os.listdir(vectorstore_path):
                print("temp")
                # vector_store = load_vector_store(vectorstore_path, self.embeddings)
                # vector_store.add_documents(docs)
                # torch_gc()
            else:
                if not vectorstore_path:
                    vectorstore_path = ""
                vector_store = FAISS.from_documents(docs, self.embeddings)
                #vector_store.save_local(vectorstore_path)
            return vector_store, loaded_files
        else:
            logger.info("file load failed")

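# Minimal sketch of the intended call pattern (the path and query are
# illustrative, not from the repository):
#   qa_db = QA_Localdb()
#   qa_db.init_cfg(llm_model=model, embedding_model="sentence-transformers/paraphrase-MiniLM-L6-v2")
#   store, ok_files = qa_db.init_knowledge_vector_store("data/test/")
#   hits = store.similarity_search("what does the paper propose?", k=qa_db.top_k)
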
'''
def delete_file_from_vector_store(self,
                                  filepath: str or List[str],
                                  vs_path):
    vector_store = load_vector_store(vs_path, self.embeddings)
    status = vector_store.delete_doc(filepath)
    return status

def update_file_from_vector_store(self,
                                  filepath: str or List[str],
                                  vs_path,
                                  docs: List[Document], ):
    vector_store = load_vector_store(vs_path, self.embeddings)
    status = vector_store.update_doc(filepath, docs)
    return status

def list_file_from_vector_store(self,
                                vs_path,
                                fullpath=False):
    vector_store = load_vector_store(vs_path, self.embeddings)
    docs = vector_store.list_docs()
    if fullpath:
        return docs
    else:
        return [os.path.split(doc)[-1] for doc in docs]
'''
def QA_model():
    # file_path = "/mnt/petrelfs/lvying/LLM/BoMA/data/test/OPUS-DSD.pdf"
    file_path = "OPUS-BioLLM-v1/data/test/Interageting-Prior-into-DA.pdf"
    # file_path = "/mnt/petrelfs/lvying/LLM/BoMA/data/test/Interageting-Prior-into-DA.pdf"
    # file_path = "/mnt/petrelfs/lvying/LLM/BoMA/data/test/"

    model_path = "/mnt/petrelfs/lvying/LLM/BoMA/models/LLM/Llama-2-13b-chat-hf"
    embedding_path = "/mnt/petrelfs/lvying/LLM/BoMA/text2vec/instructor-xl/"

    model = HuggingFacePipeline.from_model_id(model_id="daryl149/llama-2-7b-chat-hf",
                                              task="text-generation",
                                              model_kwargs={
                                                  "torch_dtype": torch.float16,
                                                  "low_cpu_mem_usage": True,
                                                  "temperature": 0.2,
                                                  "max_length": 2048,
                                                  #"device_map": "auto",
                                                  "repetition_penalty": 1.1}
                                              )
    print(model.model_id)
    QA = QA_Localdb()
    QA.init_cfg(llm_model=model, embedding_model="sentence-transformers/paraphrase-MiniLM-L6-v2")

    vector_store, _ = QA.init_knowledge_vector_store(file_path)
    retriever = vector_store.as_retriever(search_kwargs={"k": 3})

    print("loading LLM...")
    prompt_template = ("Below is an instruction that describes a task. "
                       "Write a response that appropriately completes the request.\n\n"
                       "### Instruction:\n{context}\n{question}\n\n### Response: ")

    PROMPT = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )

    chain_type_kwargs = {"prompt": PROMPT}

    #print(chain_type_kwargs)
    '''
    qa_stuff = RetrievalQA.from_chain_type(
        llm=model,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs=chain_type_kwargs,
        # verbose = True
    )
    while True:
        print("Input Question:")
        query = input()
        if len(query.strip()) == 0:
            break
        print(qa_stuff.run(query))
    '''
    '''
    qa = ConversationalRetrievalChain.from_llm(
        llm=QA.llm_model_chain,
        chain_type="stuff",
        retriever=retriever,
        combine_docs_chain_kwargs=chain_type_kwargs,
        # verbose = True
    )
    '''
    qa = RetrievalQA.from_chain_type(
        llm=QA.llm_model_chain,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs=chain_type_kwargs,
        # verbose = True
    )
    return qa
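
# A RetrievalQA built this way answers a single string query, e.g.
# (the question is illustrative):
#   qa = QA_model()
#   print(qa.run("Summarize the key contribution of the paper."))
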
qa_temp = QA_model()

def temp(query):
    return qa_temp.run(query)


def answer_question(query):
    print(query)
    # chat_history is rebuilt on every call, so the threshold below can never
    # trigger within a single invocation; kept as in the original.
    chat_history = []
    threshold_history = 10  # Remembered historical conversations
    i = 0
    if i > threshold_history:
        chat_history = []
    print("Send a Message:")
    #query = context
    #if len(query.strip())==0:
    #    break
    result = qa_temp({"question": query, "chat_history": chat_history})
    print(type(result["answer"]))
    chat_history.append((query, result["answer"]))
    i = i + 1
    resp = result["answer"]
    return str(resp)

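# Note: answer_question assumes the ConversationalRetrievalChain variant
# (it passes {"question", "chat_history"} and reads result["answer"]), but
# QA_model currently returns a RetrievalQA, so the interface below is wired
# to temp instead.
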
iface = gr.Interface(
    fn=temp,
    inputs="text",
    outputs="text",)
#title="Q&A interface",
#description="Enter a question and the related text to get an answer.",
#article="Related text goes here: paragraphs or background for the question.",
#examples=[
#    ["What is Gradio?", "Gradio is an open-source library for building and deploying machine learning models."],
#    ["Who created Python?", "Python's creator is Guido van Rossum."]
#])
#print(iface.launch(share=True))

#print("======Finish======")
#share_url = iface.share()
#print(share_url)
iface.launch()
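
# Quick local smoke test, assuming the model and file paths above resolve
# (the query is illustrative):
#   print(temp("What problem does the paper address?"))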