Syed Junaid Iqbal committed on
Commit
9fc72bf
•
1 Parent(s): ed5b7e8

Create app.py

Files changed (1)
  1. app.py +240 -0
app.py ADDED
@@ -0,0 +1,240 @@
+ import subprocess
+ import os
+ import glob
+ import tempfile
+
+ import streamlit as st
+ from dotenv import load_dotenv
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.vectorstores import Chroma
+ from langchain.embeddings import FastEmbedEmbeddings  # Local ONNX embeddings via Qdrant's FastEmbed.
+ from langchain.memory import ConversationBufferMemory
+ from langchain.chains import ConversationalRetrievalChain, RetrievalQA
+ from langchain.llms import LlamaCpp  # Runs GGUF-quantised Llama models locally.
+ from langchain.document_loaders import PyPDFLoader, TextLoader, JSONLoader, CSVLoader
+ from langchain.prompts import PromptTemplate
+ from htmlTemplates import css, bot_template, user_template
+
+
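+ # Each loader below follows the same pattern: Streamlit hands us an in-memory
+ # UploadedFile, but LangChain's loaders expect a path on disk, so the bytes are
+ # first written into a temporary directory and the loader pointed at that file.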
+ def get_pdf_text(pdf_docs):
+     temp_dir = tempfile.TemporaryDirectory()
+     temp_filepath = os.path.join(temp_dir.name, pdf_docs.name)
+
+     with open(temp_filepath, "wb") as f:
+         f.write(pdf_docs.getvalue())
+
+     pdf_loader = PyPDFLoader(temp_filepath)
+     pdf_doc = pdf_loader.load()
+     return pdf_doc
+
+
+ def get_text_file(text_docs):
+     temp_dir = tempfile.TemporaryDirectory()
+     temp_filepath = os.path.join(temp_dir.name, text_docs.name)
+
+     with open(temp_filepath, "wb") as f:
+         f.write(text_docs.getvalue())
+
+     text_loader = TextLoader(temp_filepath)
+     text_doc = text_loader.load()
+     return text_doc
+
+
+ def get_csv_file(csv_docs):
+     temp_dir = tempfile.TemporaryDirectory()
+     temp_filepath = os.path.join(temp_dir.name, csv_docs.name)
+
+     with open(temp_filepath, "wb") as f:
+         f.write(csv_docs.getvalue())
+
+     csv_loader = CSVLoader(temp_filepath)
+     csv_doc = csv_loader.load()
+     return csv_doc
+
+
+ def get_json_file(json_docs):
+     temp_dir = tempfile.TemporaryDirectory()
+     temp_filepath = os.path.join(temp_dir.name, json_docs.name)
+     with open(temp_filepath, "wb") as f:
+         f.write(json_docs.getvalue())
+
+     json_loader = JSONLoader(
+         file_path=temp_filepath,
+         jq_schema='.messages[].content',
+         text_content=False
+     )
+     json_doc = json_loader.load()
+     return json_doc
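+
+ # NOTE: the jq_schema above assumes chat-export-style JSON, e.g.
+ # {"messages": [{"content": "..."}, {"content": "..."}]};
+ # files shaped differently will need a different jq_schema.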
+
+
+ def get_text_chunks(documents):
+     # 1000-character chunks with 200 characters of overlap, so text that
+     # straddles a chunk boundary stays retrievable from either side.
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=1000,
+         chunk_overlap=200,
+         length_function=len
+     )
+
+     documents = text_splitter.split_documents(documents)
+     return documents
+
+
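+ # Chroma persists its index under ./vectordb/, which main() reuses on later
+ # runs (detected via the store's *.sqlite3 file) instead of re-embedding.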
+ def get_vectorstore(text_chunks, embeddings):
+     # embeddings = FastEmbedEmbeddings(model_name="BAAI/bge-small-en-v1.5",
+     #                                  cache_dir="./embedding_model/")
+
+     vectorstore = Chroma.from_documents(documents=text_chunks,
+                                         embedding=embeddings,
+                                         persist_directory="./vectordb/")
+     return vectorstore
+
+
+ def get_conversation_chain(vectorstore):
+     # model_name_or_path = 'TheBloke/Llama-2-7B-chat-GGUF'
+     # model_basename = 'llama-2-7b-chat.Q2_K.gguf'
+     model_path = "./models/llama-2-13b-chat.Q4_K_S.gguf"
+
+     llm = LlamaCpp(model_path=model_path,
+                    temperature=0.4,
+                    n_ctx=4000,
+                    max_tokens=4000,
+                    n_gpu_layers=50,   # layers offloaded to the GPU
+                    n_batch=512,
+                    verbose=True)
+
+     memory = ConversationBufferMemory(
+         memory_key='chat_history', return_messages=True)
+
+     # prompt template 📝
+     template = """
+     You are an experienced Human Resources manager. When an employee asks you a question, refer to the company policy and respond professionally. Make sure to sound empathetic while staying professional, and sound like a human!
+     Try to summarise the content and keep the answer to the point.
+     If you don't know the answer, just say that you don't know; don't try to make up an answer.
+     When generating an answer for the given question, make sure to follow the example template!
+     Example:
+     Question: how many paid leaves do i have?
+     Answer: The number of paid leaves varies depending on the type of leave; for privilege leave you're entitled to a maximum of 21 days in a calendar year. Other leaves might have different entitlements. Thanks for asking!
+     Make sure to add "thanks for asking!" after every answer.
+
+     {context}
+
+     Question: {question}
+     Answer:
+     """
+
+     rag_prompt_custom = PromptTemplate.from_template(template)
+
+     conversation_chain = RetrievalQA.from_chain_type(
+         llm,
+         retriever=vectorstore.as_retriever(),
+         chain_type_kwargs={"prompt": rag_prompt_custom},
+         memory=memory
+     )
+     return conversation_chain
+
+
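+ # Streamlit reruns this script from the top on every interaction, so the chat
+ # transcript is kept in st.session_state.messages to survive reruns.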
+ def handle_userinput():
+     # Reset both the visible transcript and the chain's conversation memory.
+     if st.button("Clear Chat history"):
+         st.session_state.messages = []
+         if st.session_state.conversation is not None:
+             st.session_state.conversation.memory.clear()
+
+     if "messages" not in st.session_state:
+         st.session_state.messages = [{"role": "assistant", "content": "How can I help you?"}]
+
+     for msg in st.session_state.messages:
+         st.chat_message(msg["role"]).write(msg["content"])
+
+     if prompt := st.chat_input():
+         st.session_state.messages.append({"role": "user", "content": prompt})
+         st.chat_message("user").write(prompt)
+         msg = st.session_state.conversation.run(prompt)
+         st.session_state.messages.append({"role": "assistant", "content": msg})
+         st.chat_message("assistant").write(msg)
+
+
+ # Apply rounded corners to the sidebar image via injected CSS; st.image cannot
+ # take a CSS class, so the rule targets img elements directly.
+ def add_rounded_edges(image_path="./randstad_featuredimage.png", radius=30):
+     st.markdown(
+         f'<style>img{{border-radius: {radius}px;}}</style>',
+         unsafe_allow_html=True)
+     st.image(image_path, use_column_width=True, output_format='auto')
+
+
+ def main():
+     load_dotenv()
+     st.set_page_config(page_title="Chat with multiple Files",
+                        page_icon=":books:")
+     st.write(css, unsafe_allow_html=True)
+
+     if "conversation" not in st.session_state:
+         st.session_state.conversation = None
+     if "chat_history" not in st.session_state:
+         st.session_state.chat_history = None
+
+     st.title("💬 Randstad HR Chatbot")
+     st.subheader("🚀 An HR chatbot powered by Generative AI")
+     # user_question = st.text_input("Ask a question about your documents:")
+
+     st.session_state.embeddings = FastEmbedEmbeddings(model_name="BAAI/bge-small-en-v1.5",
+                                                       cache_dir="./embedding_model/")
+
+     # Fast path: if a persisted Chroma index already exists, reuse it.
+     if len(glob.glob("./vectordb/*.sqlite3")) > 0:
+         vectorstore = Chroma(persist_directory="./vectordb/", embedding_function=st.session_state.embeddings)
+         st.session_state.conversation = get_conversation_chain(vectorstore)
+         handle_userinput()
+
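+     # Sidebar: upload documents, then click Process to embed them into Chroma;
+     # files are routed to a loader by the MIME type the browser reports.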
+     with st.sidebar:
+         add_rounded_edges()
+
+         st.subheader("Your documents")
+         docs = st.file_uploader(
+             "Upload File (pdf, text, csv...) and click 'Process'", accept_multiple_files=True)
+         if st.button("Process"):
+             with st.spinner("Processing"):
+                 doc_list = []
+
+                 for file in docs:
+                     print('file - type : ', file.type)
+                     if file.type == 'text/plain':
+                         # file is .txt
+                         doc_list.extend(get_text_file(file))
+                     elif file.type in ['application/octet-stream', 'application/pdf']:
+                         # file is .pdf (some browsers report PDFs as octet-stream)
+                         doc_list.extend(get_pdf_text(file))
+                     elif file.type == 'text/csv':
+                         # file is .csv
+                         doc_list.extend(get_csv_file(file))
+                     elif file.type == 'application/json':
+                         # file is .json
+                         doc_list.extend(get_json_file(file))
+
+                 # split the documents into overlapping chunks
+                 text_chunks = get_text_chunks(doc_list)
+
+                 # create the vector store
+                 vectorstore = get_vectorstore(text_chunks, st.session_state.embeddings)
+
+                 # create the conversation chain
+                 st.session_state.conversation = get_conversation_chain(vectorstore)
+
+
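+ # llama-cpp-python must be compiled with cuBLAS for LlamaCpp's n_gpu_layers
+ # offloading to take effect, hence the rebuild below before launching the app.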
+ if __name__ == '__main__':
+     # Reinstall llama-cpp-python with CUDA (cuBLAS) support enabled.
+     command = 'CMAKE_ARGS="-DLLAMA_CUBLAS=on" FORCE_CMAKE=1 pip install llama-cpp-python --no-cache-dir'
+
+     # Run the command using subprocess
+     try:
+         subprocess.run(command, shell=True, check=True)
+         print("Command executed successfully.")
+     except subprocess.CalledProcessError as e:
+         print(f"Error: {e}")
+
+     main()