import os
import time
from tempfile import NamedTemporaryFile

import streamlit as st
from langchain_community.llms import LlamaCpp
from langchain_community.llms import HuggingFaceHub
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Read the Hugging Face token from the Space secrets and expose it for HuggingFaceHub.
SECRET_TOKEN = os.getenv("HF_TOKEN")
if SECRET_TOKEN is None:
    raise RuntimeError("HF_TOKEN is not set; add it as a secret in the Space settings.")
os.environ["HUGGINGFACEHUB_API_TOKEN"] = SECRET_TOKEN

# Sidebar contents
with st.sidebar:
    st.title('DOC-QA DEMO')
    st.markdown('''
    ## About
    Details of this application:
    - LLM model: Mistral-7B-Instruct-v0.2 via HuggingFaceHub (a local 4-bit Phi-2 loader is also included)
    - Hardware: Hugging Face Space, 8 vCPU / 32 GB RAM
    ''')
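

# Split loaded documents into ~1000-character chunks so each retrieved piece fits
# the model's context window; the 200-character overlap preserves continuity
# across chunk boundaries.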
def split_docs(documents, chunk_size=1000):
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,
                                                   chunk_overlap=200)
    sp_docs = text_splitter.split_documents(documents)
    return sp_docs
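

# Local alternative to the hosted model: load a 4-bit quantized Phi-2 GGUF file
# with llama.cpp (despite the function name, the path points at Phi-2). This
# assumes phi-2.Q4_K_M.gguf sits in the working directory; it is unused unless
# main() swaps it in for the HuggingFaceHub LLM.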
def load_llama2_llamaCpp():
    core_model_name = "phi-2.Q4_K_M.gguf"
    #n_gpu_layers = 32
    n_batch = 512
    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
    llm = LlamaCpp(
        model_path=core_model_name,
        #n_gpu_layers=n_gpu_layers,
        n_batch=n_batch,
        callback_manager=callback_manager,
        verbose=True,
        n_ctx=4096,
        temperature=0.1,
        max_tokens=128,
    )
    return llm
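

# Prompt template for the "stuff" QA chain: the retrieved chunks are injected
# into {context} and the user's message into {question}.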
def set_custom_prompt():
    custom_prompt_template = """Use the following pieces of information from the context to answer the user's question.
If you don't know the answer, don't try to make up an answer.

Context : {context}
Question : {question}

Please answer the question in a concise and straightforward manner.
Helpful answer:
"""
    prompt = PromptTemplate(template=custom_prompt_template,
                            input_variables=['context', 'question'])
    return prompt
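

# Sentence embeddings used to build the FAISS index; gte-base is small enough
# to run acceptably on CPU.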
def load_embeddings():
    embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-base",
                                       model_kwargs={'device': 'cpu'})
    return embeddings
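

# Streamlit entry point: upload a PDF, chunk and index it, then answer chat
# questions over the indexed content.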
def main():
    data = []
    sp_docs_list = []

    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Hosted LLM via the Hugging Face Inference API; swap in load_llama2_llamaCpp()
    # to run the quantized Phi-2 model locally instead.
    repo_id = "mistralai/Mistral-7B-Instruct-v0.2"
    llm = HuggingFaceHub(
        repo_id=repo_id, model_kwargs={"temperature": 0.1, "max_length": 128})
    # llm = load_llama2_llamaCpp()
    qa_prompt = set_custom_prompt()
    embeddings = load_embeddings()

    uploaded_file = st.file_uploader('Choose your .pdf file', type="pdf")
    if uploaded_file is not None:
        # Persist the upload to disk so PyPDFLoader can read it, then clean up.
        os.makedirs('PDF', exist_ok=True)
        with NamedTemporaryFile(dir='PDF', suffix='.pdf', delete=False) as f:
            f.write(uploaded_file.getbuffer())
        loader = PyPDFLoader(f.name)
        pages = loader.load_and_split()
        data.extend(pages)
        os.unlink(f.name)

    if len(data) > 0:
        sp_docs = split_docs(documents=data)
        st.write(f"This document has {len(sp_docs)} chunks")
        sp_docs_list.extend(sp_docs)
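
    # Build the retrieval pipeline. FAISS.from_documents raises on an empty list,
    # so the except branch below doubles as the "no file uploaded yet" message.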
    try:
        db = FAISS.from_documents(sp_docs_list, embeddings)
        memory = ConversationBufferMemory(memory_key="chat_history",
                                          return_messages=True,
                                          input_key="query",
                                          output_key="result")
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=db.as_retriever(search_kwargs={'k': 3}),
            return_source_documents=True,
            memory=memory,
            chain_type_kwargs={"prompt": qa_prompt})

        # Replay the conversation stored in the Streamlit session state.
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        # Accept user input
        if query := st.chat_input("What is up?"):
            # Display user message in chat message container
            with st.chat_message("user"):
                st.markdown(query)
            # Add user message to chat history
            st.session_state.messages.append({"role": "user", "content": query})

            start = time.time()
            response = qa_chain({'query': query})
            end = time.time()
            with st.chat_message("assistant"):
                st.markdown(response['result'])
            st.write("Response time:", int(end - start), "sec")
            # Add assistant response to chat history
            st.session_state.messages.append({"role": "assistant", "content": response['result']})
            with st.expander("See the related documents"):
                for count, doc in enumerate(response['source_documents']):
                    st.write(str(count + 1) + ":", doc)

        clear_button = st.button("Start new conversation")
        if clear_button:
            st.session_state.messages = []
            qa_chain.memory.chat_memory.clear()
    except Exception:
        st.write("Please upload your PDF file.")


if __name__ == '__main__':
    main()
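
# To try the app locally (assuming this file is saved as app.py and streamlit,
# langchain, and the other imports above are installed):
#   streamlit run app.py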