Spaces:
Sleeping
Sleeping
KrishnaKumar23
committed on
Commit
•
e797f63
1
Parent(s):
028692e
Changed LLM to Mixtral-8x7B-Instruct-v0.1
Browse files- app.py +165 -45
- llm_model.py +104 -83
- requirements.txt +2 -0
- static/temp.txt +0 -0
app.py
CHANGED
@@ -3,11 +3,21 @@ from streamlit_lottie import st_lottie
|
|
3 |
import fitz # PyMuPDF
|
4 |
import requests
|
5 |
import os, shutil
|
6 |
-
import sidebar
|
7 |
import llm_model
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
@st.cache_data(experimental_allow_widgets=True)
|
10 |
-
def index_document(uploaded_file):
|
11 |
|
12 |
if uploaded_file is not None:
|
13 |
# Specify the folder path where you want to store the uploaded file in the 'assets' folder
|
@@ -24,8 +34,9 @@ def index_document(uploaded_file):
|
|
24 |
st.success(f"File '{file_name}' uploaded !")
|
25 |
|
26 |
with st.spinner("Indexing document... This is a free CPU version and may take a while ⏳"):
|
27 |
-
|
28 |
-
|
|
|
29 |
return file_name
|
30 |
else:
|
31 |
return None
|
@@ -44,11 +55,135 @@ def is_query_valid(query: str) -> bool:
|
|
44 |
return False
|
45 |
return True
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
# Function to load model parameters
|
49 |
@st.cache_resource()
|
50 |
def load_model():
|
51 |
-
|
|
|
52 |
|
53 |
st.set_page_config(page_title="Document QA Bot")
|
54 |
lottie_book = load_lottieurl("https://assets4.lottiefiles.com/temp/lf20_aKAfIn.json")
|
@@ -56,44 +191,29 @@ st_lottie(lottie_book, speed=1, height=200, key="initial")
|
|
56 |
# Place the title below the Lottie animation
|
57 |
st.title("Document Q&A Bot 🤖")
|
58 |
|
|
|
|
|
59 |
# Left Sidebar
|
60 |
-
sidebar
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
# Output Columns
|
84 |
-
answer_col, sources_col = st.columns(2)
|
85 |
-
|
86 |
-
qa_chain = llm_model.document_parser(instructor_embeddings, llm)
|
87 |
-
result = qa_chain(query)
|
88 |
-
|
89 |
-
with answer_col:
|
90 |
-
st.markdown("#### Answer")
|
91 |
-
st.markdown(result["result"])
|
92 |
-
|
93 |
-
with sources_col:
|
94 |
-
st.markdown("#### Sources")
|
95 |
-
if not ("i don't know" in result["result"].lower()):
|
96 |
-
for source in result["source_documents"]:
|
97 |
-
st.markdown(source.page_content)
|
98 |
-
st.markdown(source.metadata["source"])
|
99 |
-
st.markdown("--------------------------")
|
|
|
3 |
import fitz # PyMuPDF
|
4 |
import requests
|
5 |
import os, shutil
|
|
|
6 |
import llm_model
|
7 |
|
8 |
+
|
9 |
+
SYSTEM_PROMPT = [
|
10 |
+
"""
|
11 |
+
You are not Mistral AI, but rather a Q&A bot trained by Krishna Kumar while building a cool side project based on RAG. Whenever asked, you need to answer as Q&A bot.
|
12 |
+
""",
|
13 |
+
"""You are a RAG based Document Q&A bot. Based on the input prompt and retrieved context from the vector database you will answer questions that are closer to the context.
|
14 |
+
If no context was found then, say "I don't know" instead of making up answer on your own. Follow above rules strictly.
|
15 |
+
"""
|
16 |
+
]
|
17 |
+
|
18 |
+
|
19 |
@st.cache_data(experimental_allow_widgets=True)
|
20 |
+
def index_document(_llm_object, uploaded_file):
|
21 |
|
22 |
if uploaded_file is not None:
|
23 |
# Specify the folder path where you want to store the uploaded file in the 'assets' folder
|
|
|
34 |
st.success(f"File '{file_name}' uploaded !")
|
35 |
|
36 |
with st.spinner("Indexing document... This is a free CPU version and may take a while ⏳"):
|
37 |
+
retriever = _llm_object.create_vector_db(file_name)
|
38 |
+
st.session_state.retriever = retriever
|
39 |
+
|
40 |
return file_name
|
41 |
else:
|
42 |
return None
|
|
|
55 |
return False
|
56 |
return True
|
57 |
|
58 |
+
def init_state() :
|
59 |
+
if "filename" not in st.session_state:
|
60 |
+
st.session_state.filename = None
|
61 |
+
|
62 |
+
if "messages" not in st.session_state:
|
63 |
+
st.session_state.messages = []
|
64 |
+
|
65 |
+
if "temp" not in st.session_state:
|
66 |
+
st.session_state.temp = 0.7
|
67 |
+
|
68 |
+
if "history" not in st.session_state:
|
69 |
+
st.session_state.history = [SYSTEM_PROMPT]
|
70 |
+
|
71 |
+
if "repetion_penalty" not in st.session_state :
|
72 |
+
st.session_state.repetion_penalty = 1
|
73 |
+
|
74 |
+
if "chat_bot" not in st.session_state :
|
75 |
+
st.session_state.chat_bot = "Mixtral-8x7B-Instruct-v0.1"
|
76 |
+
|
77 |
+
|
78 |
+
def faq():
|
79 |
+
st.markdown(
|
80 |
+
"""
|
81 |
+
# FAQ
|
82 |
+
## How does Document Q&A Bot work?
|
83 |
+
When you upload a document (in Pdf, word, csv or txt format), it will be divided into smaller chunks
|
84 |
+
and stored in a special type of database called a vector index
|
85 |
+
that allows for semantic search and retrieval.
|
86 |
+
|
87 |
+
When you ask a question, our Q&A bot will first look through the document chunks and find the
|
88 |
+
most relevant ones using the vector index. This acts as a context to our custom prompt which will be feed to the LLM model.
|
89 |
+
If the context was not found in the document then, LLM will reply 'I don't know'
|
90 |
+
|
91 |
+
## Is my data safe?
|
92 |
+
Yes, your data is safe. Our bot does not store your documents or
|
93 |
+
questions. All uploaded data is deleted after you close the browser tab.
|
94 |
+
|
95 |
+
## Why does it take so long to index my document?
|
96 |
+
Since, this is a sample QA bot project that uses open-source model
|
97 |
+
and doesn't have much resource capabilities like GPU, it may take time
|
98 |
+
to index your document based on the size of the document.
|
99 |
+
|
100 |
+
## Are the answers 100% accurate?
|
101 |
+
No, the answers are not 100% accurate.
|
102 |
+
But for most use cases, our QA bot is very accurate and can answer
|
103 |
+
most questions. Always check with the sources to make sure that the answers
|
104 |
+
are correct.
|
105 |
+
"""
|
106 |
+
)
|
107 |
+
|
108 |
+
|
109 |
+
def sidebar():
|
110 |
+
with st.sidebar:
|
111 |
+
st.markdown("## Document Q&A Bot")
|
112 |
+
st.write("LLM: Mixtral-8x7B-Instruct-v0.1")
|
113 |
+
#st.success('API key already provided!', icon='✅')
|
114 |
+
|
115 |
+
st.markdown("### Set Model Parameters")
|
116 |
+
# select LLM model
|
117 |
+
st.session_state.model_name = 'Mixtral-8x7B-Instruct-v0.1'
|
118 |
+
# set model temperature
|
119 |
+
st.session_state.temperature = st.slider(label="Temperature", min_value=0.0, max_value=1.0, step=0.1, value=0.7)
|
120 |
+
st.session_state.top_p = st.slider(label="Top Probablity", min_value=0.0, max_value=1.0, step=0.1, value=0.95)
|
121 |
+
st.session_state.repetition_penalty = st.slider(label="Repetition Penalty", min_value=0.0, max_value=1.0, step=0.1, value=1.0)
|
122 |
+
|
123 |
+
# load model parameters
|
124 |
+
st.session_state.llm_object = load_model()
|
125 |
+
st.markdown("---")
|
126 |
+
# Upload file through Streamlit
|
127 |
+
st.session_state.uploaded_file = st.file_uploader("Upload a file", type=["pdf", "doc", "docx", "txt"])
|
128 |
+
index_document(st.session_state.llm_object, st.session_state.uploaded_file)
|
129 |
+
|
130 |
+
st.markdown("---")
|
131 |
+
st.markdown("# About")
|
132 |
+
st.markdown(
|
133 |
+
"""QA bot 🤖 allows you to ask questions about your
|
134 |
+
documents and get accurate answers with citations."""
|
135 |
+
)
|
136 |
+
|
137 |
+
st.markdown("Created with ❤️ by Krishna Kumar Yadav")
|
138 |
+
st.markdown(
|
139 |
+
"""
|
140 |
+
- [Email](mailto:krishna158@live.com)
|
141 |
+
- [LinkedIn](https://www.linkedin.com/in/krishna-kumar-yadav-726831105/)
|
142 |
+
- [Github](https://github.com/krish-yadav23)
|
143 |
+
- [LeetCode](https://leetcode.com/KrishnaKumar23/)
|
144 |
+
"""
|
145 |
+
)
|
146 |
+
|
147 |
+
faq()
|
148 |
+
|
149 |
+
|
150 |
+
def chat_box() :
|
151 |
+
for message in st.session_state.messages:
|
152 |
+
with st.chat_message(message["role"]):
|
153 |
+
st.markdown(message["content"])
|
154 |
+
|
155 |
+
|
156 |
+
def generate_chat_stream(prompt) :
|
157 |
+
|
158 |
+
with st.spinner("Fetching relevant answers from source document..."):
|
159 |
+
response, sources = st.session_state.llm_object.mixtral_chat_inference(prompt, st.session_state.history, st.session_state.temperature,
|
160 |
+
st.session_state.top_p, st.session_state.repetition_penalty, st.session_state.retriever)
|
161 |
+
|
162 |
+
|
163 |
+
return response, sources
|
164 |
+
|
165 |
+
def stream_handler(chat_stream, placeholder) :
|
166 |
+
full_response = ''
|
167 |
+
|
168 |
+
for chunk in chat_stream :
|
169 |
+
if chunk.token.text!='</s>' :
|
170 |
+
full_response += chunk.token.text
|
171 |
+
placeholder.markdown(full_response + "β")
|
172 |
+
placeholder.markdown(full_response)
|
173 |
+
|
174 |
+
return full_response
|
175 |
+
|
176 |
+
def show_source(sources) :
|
177 |
+
with st.expander("Show source") :
|
178 |
+
for source in sources:
|
179 |
+
st.info(f"{source}")
|
180 |
+
|
181 |
|
182 |
# Function to load model parameters
|
183 |
@st.cache_resource()
|
184 |
def load_model():
|
185 |
+
# create llm object
|
186 |
+
return llm_model.LlmModel()
|
187 |
|
188 |
st.set_page_config(page_title="Document QA Bot")
|
189 |
lottie_book = load_lottieurl("https://assets4.lottiefiles.com/temp/lf20_aKAfIn.json")
|
|
|
191 |
# Place the title below the Lottie animation
|
192 |
st.title("Document Q&A Bot 🤖")
|
193 |
|
194 |
+
# initialize session state for streamlit app
|
195 |
+
init_state()
|
196 |
# Left Sidebar
|
197 |
+
sidebar()
|
198 |
+
chat_box()
|
199 |
+
|
200 |
+
if prompt := st.chat_input("Ask a question about your document!"):
|
201 |
+
st.chat_message("user").markdown(prompt)
|
202 |
+
st.session_state.messages.append({"role": "user", "content": prompt})
|
203 |
+
|
204 |
+
try:
|
205 |
+
chat_stream, sources = generate_chat_stream(prompt)
|
206 |
+
|
207 |
+
with st.chat_message("assistant"):
|
208 |
+
placeholder = st.empty()
|
209 |
+
full_response = stream_handler(chat_stream, placeholder)
|
210 |
+
show_source(sources)
|
211 |
+
|
212 |
+
st.session_state.history.append([prompt, full_response])
|
213 |
+
st.session_state.messages.append({"role": "assistant", "content": full_response})
|
214 |
+
except Exception as e:
|
215 |
+
if not st.session_state.uploaded_file:
|
216 |
+
st.error("Kindly provide the document file by uploading it before posing any questions. Your cooperation is appreciated!")
|
217 |
+
else:
|
218 |
+
st.error(e)
|
219 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
llm_model.py
CHANGED
@@ -1,92 +1,113 @@
|
|
1 |
from langchain.vectorstores import FAISS
|
2 |
-
from langchain.llms import GooglePalm
|
3 |
-
from langchain.document_loaders import PyPDFLoader
|
4 |
-
from langchain.
|
5 |
-
from langchain.document_loaders import Docx2txtLoader
|
6 |
-
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
7 |
from langchain.prompts import PromptTemplate
|
8 |
from langchain.chains import RetrievalQA
|
9 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
|
10 |
import os
|
11 |
from dotenv import load_dotenv
|
12 |
|
13 |
vector_index_path = "assets/vectordb/faiss_index"
|
14 |
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from langchain.vectorstores import FAISS
|
2 |
+
#from langchain.llms import GooglePalm, CTransformers
|
3 |
+
from langchain.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
|
4 |
+
from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings
|
|
|
|
|
5 |
from langchain.prompts import PromptTemplate
|
6 |
from langchain.chains import RetrievalQA
|
7 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
8 |
+
from huggingface_hub import InferenceClient
|
9 |
import os
|
10 |
from dotenv import load_dotenv
|
11 |
|
12 |
vector_index_path = "assets/vectordb/faiss_index"
|
13 |
|
14 |
+
class LlmModel:
|
15 |
+
|
16 |
+
def __init__(self):
|
17 |
+
# load dot env variables
|
18 |
+
self.load_env_variables()
|
19 |
+
# load llm model
|
20 |
+
self.hf_embeddings = self.load_huggingface_embeddings()
|
21 |
+
|
22 |
+
def load_env_variables(self):
|
23 |
+
load_dotenv() # take environment variables from .env
|
24 |
+
|
25 |
+
def custom_prompt(self, question, history, context):
|
26 |
+
#RAG prompt template
|
27 |
+
prompt = "<s>"
|
28 |
+
for user_prompt, bot_response in history: # provide chat history
|
29 |
+
prompt += f"[INST] {user_prompt} [/INST]"
|
30 |
+
prompt += f" {bot_response}</s>"
|
31 |
+
|
32 |
+
message_prompt = f"""
|
33 |
+
You are a question answer agent and you must strictly follow below prompt template.
|
34 |
+
Given the following context and a question, generate an answer based on this context only.
|
35 |
+
Keep answers brief and well-structured. Do not give one word answers.
|
36 |
+
If the answer is not found in the context, kindly state "I don't know." Don't try to make up an answer.
|
37 |
+
|
38 |
+
CONTEXT: {context}
|
39 |
+
|
40 |
+
QUESTION: {question}
|
41 |
+
"""
|
42 |
+
prompt += f"[INST] {message_prompt} [/INST]"
|
43 |
+
|
44 |
+
return prompt
|
45 |
+
|
46 |
+
def format_sources(self, sources):
|
47 |
+
# format the document sources
|
48 |
+
source_results = []
|
49 |
+
for source in sources:
|
50 |
+
source_results.append(str(source.page_content) +
|
51 |
+
"\n Document: " + str(source.metadata['source']) +
|
52 |
+
" Page: " + str(source.metadata['page']))
|
53 |
+
return source_results
|
54 |
+
|
55 |
+
def mixtral_chat_inference(self, prompt, history, temperature, top_p, repetition_penalty, retriever):
|
56 |
+
|
57 |
+
context = retriever.get_relevant_documents(prompt)
|
58 |
+
sources = self.format_sources(context)
|
59 |
+
# use hugging face infrence api
|
60 |
+
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1",
|
61 |
+
token=os.environ["HF_TOKEN"]
|
62 |
+
)
|
63 |
+
temperature = float(temperature)
|
64 |
+
if temperature < 1e-2:
|
65 |
+
temperature = 1e-2
|
66 |
+
|
67 |
+
generate_kwargs = dict(
|
68 |
+
temperature = temperature,
|
69 |
+
max_new_tokens = 512,
|
70 |
+
top_p = top_p,
|
71 |
+
repetition_penalty = repetition_penalty,
|
72 |
+
do_sample = True
|
73 |
+
)
|
74 |
+
|
75 |
+
formatted_prompt = self.custom_prompt(prompt, history, context)
|
76 |
+
|
77 |
+
return client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False), sources
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
def load_huggingface_embeddings(self):
|
82 |
+
# Initialize instructor embeddings using the Hugging Face model
|
83 |
+
#return HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-large")
|
84 |
+
return HuggingFaceEmbeddings(model_name = "sentence-transformers/all-MiniLM-L6-v2",
|
85 |
+
model_kwargs={'device': 'cpu'})
|
86 |
+
|
87 |
+
|
88 |
+
|
89 |
+
def create_vector_db(self, filename):
|
90 |
+
|
91 |
+
if filename.endswith(".pdf"):
|
92 |
+
loader = PyPDFLoader(file_path=filename)
|
93 |
+
elif filename.endswith(".doc") or filename.endswith(".docx"):
|
94 |
+
loader = Docx2txtLoader(filename)
|
95 |
+
elif filename.endswith("txt") or filename.endswith("TXT"):
|
96 |
+
loader = TextLoader(filename)
|
97 |
+
|
98 |
+
# Split documents
|
99 |
+
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
|
100 |
+
splits = text_splitter.split_documents(loader.load())
|
101 |
+
|
102 |
+
# Create a FAISS instance for vector database from 'data'
|
103 |
+
vectordb = FAISS.from_documents(documents = splits,
|
104 |
+
embedding = self.hf_embeddings)
|
105 |
+
|
106 |
+
# Save vector database locally
|
107 |
+
#vectordb.save_local(vector_index_path)
|
108 |
+
|
109 |
+
# set vectordb content
|
110 |
+
# Load the vector database from the local folder
|
111 |
+
#vectordb = FAISS.load_local(vector_index_path, self.hf_embeddings)
|
112 |
+
# Create a retriever for querying the vector database
|
113 |
+
return vectordb.as_retriever(search_type="similarity")
|
requirements.txt
CHANGED
@@ -13,3 +13,5 @@ frontend
|
|
13 |
tools
|
14 |
docx2txt
|
15 |
fitz
|
|
|
|
|
|
13 |
tools
|
14 |
docx2txt
|
15 |
fitz
|
16 |
+
huggingface_hub
|
17 |
+
chainlit
|
static/temp.txt
DELETED
File without changes
|