Spaces:
Runtime error
Runtime error
import gradio as gr | |
from langchain.embeddings.openai import OpenAIEmbeddings | |
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter | |
from langchain.vectorstores import DocArrayInMemorySearch | |
from langchain.chains import RetrievalQA, ConversationalRetrievalChain | |
from langchain.memory import ConversationBufferMemory | |
from langchain.chat_models import ChatOpenAI | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain import HuggingFaceHub | |
from langchain.llms import LlamaCpp | |
from huggingface_hub import hf_hub_download | |
import param | |
import os | |
import torch | |
from langchain.document_loaders import ( | |
EverNoteLoader, | |
TextLoader, | |
UnstructuredEPubLoader, | |
UnstructuredHTMLLoader, | |
UnstructuredMarkdownLoader, | |
UnstructuredODTLoader, | |
UnstructuredPowerPointLoader, | |
UnstructuredWordDocumentLoader, | |
PyPDFLoader, | |
) | |
#YOUR_HF_TOKEN = os.getenv("My_hf_token") | |
llm_api=HuggingFaceHub( | |
huggingfacehub_api_token=os.getenv("My_hf_token"), | |
repo_id="tiiuae/falcon-7b-instruct", | |
model_kwargs={ | |
"temperature":0.2, | |
"max_new_tokens":500, | |
"top_k":50, | |
"top_p":0.95, | |
"repetition_penalty":1.2, | |
},), #ChatOpenAI(model_name=llm_name, temperature=0) | |
#alter | |
def load_db(files, chain_type, k, llm): | |
EXTENSIONS = { | |
".txt": (TextLoader, {"encoding": "utf8"}), | |
".pdf": (PyPDFLoader, {}), | |
".doc": (UnstructuredWordDocumentLoader, {}), | |
".docx": (UnstructuredWordDocumentLoader, {}), | |
".enex": (EverNoteLoader, {}), | |
".epub": (UnstructuredEPubLoader, {}), | |
".html": (UnstructuredHTMLLoader, {}), | |
".md": (UnstructuredMarkdownLoader, {}), | |
".odt": (UnstructuredODTLoader, {}), | |
".ppt": (UnstructuredPowerPointLoader, {}), | |
".pptx": (UnstructuredPowerPointLoader, {}), | |
} | |
# select extensions loader | |
documents = [] | |
for file in files: | |
ext = "." + file.rsplit(".", 1)[-1] | |
if ext in EXTENSIONS: | |
loader_class, loader_args = EXTENSIONS[ext] | |
loader = loader_class(file, **loader_args) | |
documents.extend(loader.load()) | |
else: | |
pass | |
# load documents | |
if documents == []: | |
loader_class, loader_args = EXTENSIONS['.txt'] | |
loader = loader_class('demo_docs/demo.txt', **loader_args) | |
documents = loader.load() | |
# split documents | |
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150) | |
docs = text_splitter.split_documents(documents) | |
# define embedding | |
embeddings = HuggingFaceEmbeddings(model_name='all-MiniLM-L6-v2') # all-mpnet-base-v2 #embeddings = OpenAIEmbeddings() | |
# create vector database from data | |
db = DocArrayInMemorySearch.from_documents(docs, embeddings) | |
# define retriever | |
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": k}) | |
# create a chatbot chain. Memory is managed externally. | |
qa = ConversationalRetrievalChain.from_llm( | |
llm=llm, | |
chain_type=chain_type, | |
retriever=retriever, | |
return_source_documents=True, | |
return_generated_question=True, | |
) | |
return qa | |
class DocChat(param.Parameterized): | |
chat_history = param.List([]) | |
answer = param.String("") | |
db_query = param.String("") | |
db_response = param.List([]) | |
llm = llm_api[0] | |
k_value = param.Integer(3) | |
def __init__(self, **params): | |
super(DocChat, self).__init__( **params) | |
self.loaded_file = "demo_docs/demo.txt" | |
self.qa = load_db(self.loaded_file,"stuff", self.k_value, self.llm) | |
def call_load_db(self, path_file, k): | |
if not os.path.exists(path_file[0]): # init or no file specified | |
return "No file loaded" | |
else: | |
try: | |
self.qa = load_db(path_file, "stuff", k, self.llm) | |
self.loaded_file = path_file | |
except: | |
return f'No valid file' | |
self.clr_history() | |
return f"New DB created | Loaded File: {self.loaded_file}" | |
# chat | |
def convchain(self, query, k_max): | |
if k_max != self.k_value: | |
print("Maximum querys changed, reloading DB") | |
self.qa = load_db(self.loaded_file,"stuff", k_max, self.llm) | |
self.k_value = k_max | |
result = self.qa({"question": query, "chat_history": self.chat_history}) | |
self.chat_history.extend([(query, result["answer"])]) | |
self.db_query = result["generated_question"] | |
self.db_response = result["source_documents"] | |
self.answer = result['answer'] | |
return self.answer | |
def change_llm(self, repo_, file_, max_tokens=16, temperature=0.2, top_p=0.95, top_k=50, repeat_penalty=1.2, k=3): | |
if torch.cuda.is_available(): | |
try: | |
model_path = hf_hub_download(repo_id=repo_, filename=file_) | |
self.llm = LlamaCpp( | |
model_path=model_path, | |
n_ctx=1000, | |
n_batch=512, | |
n_gpu_layers=35, | |
max_tokens=max_tokens, | |
verbose=False, | |
temperature=temperature, | |
top_p=top_p, | |
top_k=top_k, | |
repeat_penalty=repeat_penalty, | |
) | |
self.qa = load_db(self.loaded_file,"stuff", k, self.llm) | |
self.k_value = k | |
return f"Loaded {file_}" | |
except: | |
return "No valid model" | |
else: | |
try: | |
model_path = hf_hub_download(repo_id=repo_, filename=file_) | |
self.llm = LlamaCpp( | |
model_path=model_path, | |
n_ctx=1000, | |
n_batch=8, | |
max_tokens=max_tokens, | |
verbose=False, | |
temperature=temperature, | |
top_p=top_p, | |
top_k=top_k, | |
repeat_penalty=repeat_penalty, | |
) | |
self.qa = load_db(self.loaded_file,"stuff", k, self.llm) | |
self.k_value = k | |
return f"Loaded {file_}" | |
except: | |
return "No valid model" | |
def default_falcon_model(self): | |
self.llm = llm_api[0] | |
self.qa = load_db(self.loaded_file,"stuff", self.k_value, self.llm) | |
return "Loaded model Falcon 7B-instruct" | |
def get_lquest(self): | |
if not self.db_query : | |
return print("Last question to DB: no DB accesses so far") | |
return self.db_query | |
def get_sources(self): | |
if not self.db_response: | |
return | |
#rlist=[f"Result of DB lookup:"] | |
rlist=[] | |
for doc in self.db_response: | |
for element in doc: | |
rlist.append(element) | |
return rlist | |
def get_chats(self): | |
if not self.chat_history: | |
return "No History Yet" | |
#rlist=[f"Current Chat History variable"] | |
rlist=[] | |
for exchange in self.chat_history: | |
rlist.append(exchange) | |
return rlist | |
def clr_history(self,count=0): | |
self.chat_history = [] | |
return "HISTORY CLEARED" | |