# Earlier scratch: a commented-out Gradio grayscale demo, kept for reference.
#import gradio as gr
#import cv2
#def to_black(image):
#    output = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
#    return output
#interface = gr.Interface(fn=to_black, inputs="image", outputs="image")
#print('here')
#interface.launch()
#print(share_url)
#print(local_url)
#print(app)
#interface.launch(inbrowser=True, share=True, port=8888)
#url = interface.share()
#print(url)
from langchain.chains import RetrievalQA
from langchain.document_loaders import UnstructuredFileLoader, TextLoader, CSVLoader
from langchain.vectorstores import DocArrayInMemorySearch
from langchain.indexes import VectorstoreIndexCreator
from langchain.prompts import PromptTemplate
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain import HuggingFacePipeline
import torch
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.base import Chain
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.summarize import load_summarize_chain
import gradio as gr
from typing import List, Union
from tqdm import tqdm
import logging
import argparse
import os
import string
CHUNK_SIZE = 600
CHUNK_OVERLAP = 100
SEARCH_TOP_K = 5

logger = logging.getLogger("bio_LLM_logger")
def tree(filepath, ignore_dir_names=None, ignore_file_names=None):
    """Return two lists: the full paths of every file under filepath, and the
    corresponding file names."""
    if ignore_dir_names is None:
        ignore_dir_names = []
    if ignore_file_names is None:
        ignore_file_names = []
    ret_list = []
    if isinstance(filepath, str):
        if not os.path.exists(filepath):
            print("path does not exist")
            return None, None
        elif os.path.isfile(filepath) and os.path.basename(filepath) not in ignore_file_names:
            return [filepath], [os.path.basename(filepath)]
        elif os.path.isdir(filepath) and os.path.basename(filepath) not in ignore_dir_names:
            for file in os.listdir(filepath):
                fullfilepath = os.path.join(filepath, file)
                if os.path.isfile(fullfilepath) and os.path.basename(fullfilepath) not in ignore_file_names:
                    ret_list.append(fullfilepath)
                if os.path.isdir(fullfilepath) and os.path.basename(fullfilepath) not in ignore_dir_names:
                    ret_list.extend(tree(fullfilepath, ignore_dir_names, ignore_file_names)[0])
    return ret_list, [os.path.basename(p) for p in ret_list]
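
# A minimal, illustrative use of tree(); the "data" directory is an assumed
# example path, not one the original script relies on.
def _tree_example():
    paths, names = tree("data", ignore_dir_names=["tmp_files"])
    for full_path, name in zip(paths, names):
        print(name, "->", full_path)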
def load_file(file_path, chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP):
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    if file_path.lower().endswith(".pdf"):
        loader = UnstructuredFileLoader(file_path, mode="elements")
    elif file_path.lower().endswith(".txt"):
        loader = TextLoader(file_path, autodetect_encoding=True)
    elif file_path.lower().endswith(".csv"):
        loader = CSVLoader(file_path)
    else:
        # Previously this branch only printed a message and then returned an
        # undefined name; raise instead so callers can catch the failure.
        raise ValueError(f"unsupported file format: {file_path}")
    docs = loader.load_and_split(text_splitter=text_splitter)
    return docs
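
# Illustrative call, assuming a plain-text file is present; each returned
# Document holds a chunk of at most CHUNK_SIZE characters, with CHUNK_OVERLAP
# characters of overlap between neighbouring chunks.
def _load_file_example():
    docs = load_file("doc1.txt")
    print(f"{len(docs)} chunks loaded")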
def summary(model, chain_type, PROMPT, REFINE_PROMPT, docs):
    if chain_type == "stuff":
        chain = load_summarize_chain(model, chain_type="stuff", prompt=PROMPT)
    elif chain_type == "refine":
        chain = load_summarize_chain(model, chain_type="refine", question_prompt=PROMPT, refine_prompt=REFINE_PROMPT)
    else:
        # Guard added: an unknown chain_type previously hit a NameError on `chain`.
        raise ValueError(f"unsupported chain_type: {chain_type}")
    print(chain.run(docs))
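
# Sketch of how summary() could be driven; both prompt templates below are
# illustrative assumptions, not prompts shipped with this script.
def _summary_example(model, docs):
    prompt = PromptTemplate(
        template="Write a concise summary of the following:\n{text}\nSUMMARY:",
        input_variables=["text"],
    )
    refine_prompt = PromptTemplate(
        template=("Refine the existing summary:\n{existing_answer}\n"
                  "using this additional context:\n{text}\nREFINED SUMMARY:"),
        input_variables=["existing_answer", "text"],
    )
    summary(model, "refine", prompt, refine_prompt, docs)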
class QA_Localdb:
    llm_model_chain: Chain = None
    embeddings: object = None
    top_k: int = SEARCH_TOP_K
    chunk_size: int = CHUNK_SIZE

    def init_cfg(self,
                 llm_model: Chain,
                 embedding_model: str,
                 #embedding_device: str,
                 top_k=SEARCH_TOP_K,
                 ):
        self.llm_model_chain = llm_model
        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
        self.top_k = top_k
    def init_knowledge_vector_store(self,
                                    file_path: Union[str, List[str]],
                                    vectorstore_path: Union[str, os.PathLike] = None,
                                    ):
        loaded_files = []
        failed_files = []
        if isinstance(file_path, str):
            if not os.path.exists(file_path):
                print("unknown path")
                return None
            elif os.path.isfile(file_path):
                file = os.path.split(file_path)[-1]
                try:
                    docs = load_file(file_path)
                    logger.info(f"{file} loaded successfully")
                    loaded_files.append(file_path)
                except Exception as e:
                    logger.error(e)
                    logger.info(f"{file} failed to load")
                    return None
            elif os.path.isdir(file_path):
                docs = []
                for fullfilepath, file in tqdm(zip(*tree(file_path, ignore_dir_names=['tmp_files'])), desc="load file"):
                    try:
                        docs += load_file(fullfilepath)
                        loaded_files.append(fullfilepath)
                    except Exception as e:
                        logger.error(e)
                        failed_files.append(file)
                if len(failed_files) > 0:
                    logger.info("the following files failed to load:")
                    for file in failed_files:
                        logger.info(f"{file}\n")
        else:
            docs = []
            for file in file_path:
                try:
                    docs += load_file(file)
                    logger.info(f"{file} loaded successfully")
                    loaded_files.append(file)
                except Exception as e:
                    logger.error(e)
                    logger.info(f"{file} failed to load")
        if len(docs) > 0:
            logger.info("documents loaded, generating vector store")
            if vectorstore_path and os.path.isdir(vectorstore_path) and "index.faiss" in os.listdir(vectorstore_path):
                # An index already exists on disk: load it and extend it.
                # (Originally this branch only printed "temp" and left
                # vector_store undefined at the return below; FAISS.load_local
                # is the standard langchain call for reopening a saved index.)
                vector_store = FAISS.load_local(vectorstore_path, self.embeddings)
                vector_store.add_documents(docs)
            else:
                if not vectorstore_path:
                    vectorstore_path = ""
                vector_store = FAISS.from_documents(docs, self.embeddings)
                #vector_store.save_local(vectorstore_path)
            return vector_store, loaded_files
        else:
            logger.info("file load failed")
    '''
    def delete_file_from_vector_store(self,
                                      filepath: str or List[str],
                                      vs_path):
        vector_store = load_vector_store(vs_path, self.embeddings)
        status = vector_store.delete_doc(filepath)
        return status

    def update_file_from_vector_store(self,
                                      filepath: str or List[str],
                                      vs_path,
                                      docs: List[Document], ):
        vector_store = load_vector_store(vs_path, self.embeddings)
        status = vector_store.update_doc(filepath, docs)
        return status

    def list_file_from_vector_store(self,
                                    vs_path,
                                    fullpath=False):
        vector_store = load_vector_store(vs_path, self.embeddings)
        docs = vector_store.list_docs()
        if fullpath:
            return docs
        else:
            return [os.path.split(doc)[-1] for doc in docs]
    '''
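
# Quick retrieval sanity check (illustrative; assumes init_knowledge_vector_store
# returned a FAISS store). similarity_search is the standard langchain call.
def _retrieval_example(vector_store):
    hits = vector_store.similarity_search("example query", k=SEARCH_TOP_K)
    for doc in hits:
        print(doc.page_content[:80])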
def QA_model():
    # file_path = "/mnt/petrelfs/lvying/LLM/BoMA/data/test/OPUS-DSD.pdf"
    file_path = "doc1.txt"
    # file_path = "/mnt/petrelfs/lvying/LLM/BoMA/data/test/Interageting-Prior-into-DA.pdf"
    # file_path = "/mnt/petrelfs/lvying/LLM/BoMA/data/test/"
    # Local checkpoint paths kept for reference; the from_model_id call below
    # pulls the hub model instead.
    model_path = "/mnt/petrelfs/lvying/LLM/BoMA/models/LLM/Llama-2-13b-chat-hf"
    embedding_path = "/mnt/petrelfs/lvying/LLM/BoMA/text2vec/instructor-xl/"
    model = HuggingFacePipeline.from_model_id(model_id="daryl149/llama-2-7b-chat-hf",
                                              task="text-generation",
                                              model_kwargs={
                                                  "torch_dtype": torch.float32,
                                                  "low_cpu_mem_usage": True,
                                                  "temperature": 0.2,
                                                  "max_length": 2048,
                                                  # "device_map": "auto",
                                                  "repetition_penalty": 1.1}
                                              )
    print(model.model_id)

    QA = QA_Localdb()
    QA.init_cfg(llm_model=model, embedding_model="sentence-transformers/paraphrase-MiniLM-L6-v2")
    vector_store, _ = QA.init_knowledge_vector_store(file_path)
    retriever = vector_store.as_retriever(search_kwargs={"k": 3})
    print("loading LLM...")

    prompt_template = ("Below is an instruction that describes a task. "
                       "Write a response that appropriately completes the request.\n\n"
                       "### Instruction:\n{context}\n{question}\n\n### Response: ")
    PROMPT = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )
    chain_type_kwargs = {"prompt": PROMPT}
    #print(chain_type_kwargs)
    '''
    qa_stuff = RetrievalQA.from_chain_type(
        llm=model,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs=chain_type_kwargs,
        # verbose=True
    )
    while True:
        print("Input Question:")
        query = input()
        if len(query.strip()) == 0:
            break
        print(qa_stuff.run(query))
    '''
    '''
    qa = ConversationalRetrievalChain.from_llm(
        llm=QA.llm_model_chain,
        chain_type="stuff",
        retriever=retriever,
        combine_docs_chain_kwargs=chain_type_kwargs,
        # verbose=True
    )
    '''
    qa = RetrievalQA.from_chain_type(
        llm=QA.llm_model_chain,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs=chain_type_kwargs,
        # verbose=True
    )
    return qa
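
# Illustrative one-off query against the chain returned by QA_model();
# RetrievalQA.run takes the question string directly:
#
#   qa = QA_model()
#   print(qa.run("What is this document about?"))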
qa_temp = QA_model()

def temp(query):
    return qa_temp.run(query)

# Note: this function expects a ConversationalRetrievalChain-style payload
# ({"question": ..., "chat_history": ...}); it matches the commented-out
# variant in QA_model() above and is not wired to the interface below.
def answer_question(query):
    print(query)
    chat_history = []
    threshold_history = 10  # number of remembered historical exchanges
    i = 0
    if i > threshold_history:
        chat_history = []
    print("Send a Message:")
    #query = context
    #if len(query.strip()) == 0:
    #    break
    result = qa_temp({"question": query, "chat_history": chat_history})
    print(type(result["answer"]))
    chat_history.append((query, result["answer"]))
    i = i + 1
    resp = result["answer"]
    return str(resp)
iface = gr.Interface(
    fn=temp,
    inputs="text",
    outputs="text",)
    #title="Q&A Interface",
    #description="Enter a question and the related text to get an answer.",
    #article="The related text goes here. You can enter some paragraphs or background for the question.",
    #examples=[
    #    ["What is Gradio?", "Gradio is an open-source library for building and deploying machine learning models."],
    #    ["Who created Python?", "Python was created by Guido van Rossum."]
    #])
#print(iface.launch(share=True))
#print("======Finish======")
#share_url = iface.share()
#print(share_url)
iface.launch()