arjunanand13 committed on
Commit
60cf7bb
1 Parent(s): 5c9cfbf

Update app.py

Files changed (1)
app.py +181 -23
app.py CHANGED
@@ -1,6 +1,6 @@
 import os
-import json
 from langchain.document_loaders import TextLoader, DirectoryLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import FAISS
 from sentence_transformers import SentenceTransformer
 import faiss
@@ -8,7 +8,7 @@ import torch
 import numpy as np
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
 from datetime import datetime
-import gradio as gr
+import json
 
 class DocumentRetrievalAndGeneration:
     def __init__(self, embedding_model_name, lm_model_id, data_folder, faiss_index_path):
@@ -16,6 +16,7 @@ class DocumentRetrievalAndGeneration:
         self.embeddings = SentenceTransformer(embedding_model_name)
         self.gpu_index = self.load_faiss_index(faiss_index_path)
         self.llm = self.initialize_llm(lm_model_id)
+        self.all_splits = self.split_documents()
 
     def load_documents(self, folder_path):
         loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
@@ -49,6 +50,11 @@ class DocumentRetrievalAndGeneration:
         )
         return generate_text
 
+    def split_documents(self):
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=250)
+        all_splits = text_splitter.split_documents(self.documents)
+        return all_splits
+
     def query_and_generate_response(self, query):
         query_embedding = self.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
         distances, indices = self.gpu_index.search(np.array([query_embedding]), k=5)
@@ -56,28 +62,11 @@ class DocumentRetrievalAndGeneration:
         content = ""
         for idx in indices[0]:
             content += "-" * 50 + "\n"
-            content += self.documents[idx].page_content + "\n"
-            print(self.documents[idx].page_content)
+            content += self.all_splits[idx].page_content + "\n"
+            print(self.all_splits[idx].page_content)
             print("############################")
-        prompt=f"""
-        You are a knowledgeable assistant with access to a comprehensive database.
-        I need you to answer my question and provide related information in a specific format.
-        I have provided five relatable json files {content}, choose the most suitable chunks for answering the query
-        Here's what I need:
-        Include a final answer without additional comments, sign-offs, or extra phrases. Be direct and to the point.
-        content
-        Here's my question:
-        Query:{query}
-        Solution==>
-        Example1
-        Query: "How to use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM",
-        Solution: "To use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM, you need to modify the configuration file of the NDK application. Specifically, change the processor reference from 'A15_0' to 'IPU1_0'.",
-
-        Example2
-        Query: "Can BQ25896 support I2C interface?",
-        Solution: "Yes, the BQ25896 charger supports the I2C interface for communication.",
-        """
-        # prompt = f"Query: {query}\nSolution: {content}\n"
+
+        prompt = f"Query: {query}\nSolution: {content}\n"
 
         # Encode and prepare inputs
         messages = [{"role": "user", "content": prompt}]
@@ -96,7 +85,7 @@ class DocumentRetrievalAndGeneration:
         print("Time elapsed:", elapsed_time)
         print("Device in use:", self.llm.device)
 
-        return generated_response,content
+        return generated_response, content
 
     def qa_infer_gradio(self, query):
         response = self.query_and_generate_response(query)
@@ -156,3 +145,172 @@ if __name__ == "__main__":
 
     # Launch the interface
     launch_interface()
+
+
+
+# import os
+# import json
+# from langchain.document_loaders import TextLoader, DirectoryLoader
+# from langchain.vectorstores import FAISS
+# from sentence_transformers import SentenceTransformer
+# import faiss
+# import torch
+# import numpy as np
+# from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
+# from datetime import datetime
+# import gradio as gr
+
+# class DocumentRetrievalAndGeneration:
+#     def __init__(self, embedding_model_name, lm_model_id, data_folder, faiss_index_path):
+#         self.documents = self.load_documents(data_folder)
+#         self.embeddings = SentenceTransformer(embedding_model_name)
+#         self.gpu_index = self.load_faiss_index(faiss_index_path)
+#         self.llm = self.initialize_llm(lm_model_id)
+
+#     def load_documents(self, folder_path):
+#         loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
+#         documents = loader.load()
+#         print('Length of documents:', len(documents))
+#         return documents
+
+#     def load_faiss_index(self, faiss_index_path):
+#         cpu_index = faiss.read_index(faiss_index_path)
+#         gpu_resource = faiss.StandardGpuResources()
+#         gpu_index = faiss.index_cpu_to_gpu(gpu_resource, 0, cpu_index)
+#         return gpu_index
+
+#     def initialize_llm(self, model_id):
+#         bnb_config = BitsAndBytesConfig(
+#             load_in_4bit=True,
+#             bnb_4bit_use_double_quant=True,
+#             bnb_4bit_quant_type="nf4",
+#             bnb_4bit_compute_dtype=torch.bfloat16
+#         )
+#         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+#         model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
+#         tokenizer = AutoTokenizer.from_pretrained(model_id)
+#         generate_text = pipeline(
+#             model=model,
+#             tokenizer=tokenizer,
+#             return_full_text=True,
+#             task='text-generation',
+#             temperature=0.6,
+#             max_new_tokens=2048,
+#         )
+#         return generate_text
+
+#     def query_and_generate_response(self, query):
+#         query_embedding = self.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
+#         distances, indices = self.gpu_index.search(np.array([query_embedding]), k=5)
+
+#         # content = ""
+#         # for idx in indices[0]:
+#         #     content += "-" * 50 + "\n"
+#         #     content += self.documents[idx].page_content + "\n"
+#         #     print(self.documents[idx].page_content)
+#         #     print("############################")
+#         content = ""
+#         all_splits=build_faiss_index.all_splits
+#         for idx in indices[0]:
+#             content += "-" * 50 + "\n"
+
+#             content+=all_splits[idx].page_content
+#             print(all_splits[idx].page_content)
+#             print("############################")
+#         prompt=f"""
+#         You are a knowledgeable assistant with access to a comprehensive database.
+#         I need you to answer my question and provide related information in a specific format.
+#         I have provided five relatable json files {content}, choose the most suitable chunks for answering the query
+#         Here's what I need:
+#         Include a final answer without additional comments, sign-offs, or extra phrases. Be direct and to the point.
+#         content
+#         Here's my question:
+#         Query:{query}
+#         Solution==>
+#         Example1
+#         Query: "How to use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM",
+#         Solution: "To use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM, you need to modify the configuration file of the NDK application. Specifically, change the processor reference from 'A15_0' to 'IPU1_0'.",
+
+#         Example2
+#         Query: "Can BQ25896 support I2C interface?",
+#         Solution: "Yes, the BQ25896 charger supports the I2C interface for communication.",
+#         """
+#         # prompt = f"Query: {query}\nSolution: {content}\n"
+
+#         # Encode and prepare inputs
+#         messages = [{"role": "user", "content": prompt}]
+#         encodeds = self.llm.tokenizer.apply_chat_template(messages, return_tensors="pt")
+#         model_inputs = encodeds.to(self.llm.device)
+
+#         # Perform inference and measure time
+#         start_time = datetime.now()
+#         generated_ids = self.llm.model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
+#         elapsed_time = datetime.now() - start_time
+
+#         # Decode and return output
+#         decoded = self.llm.tokenizer.batch_decode(generated_ids)
+#         generated_response = decoded[0]
+#         print("Generated response:", generated_response)
+#         print("Time elapsed:", elapsed_time)
+#         print("Device in use:", self.llm.device)
+
+#         return generated_response,content
+
+#     def qa_infer_gradio(self, query):
+#         response = self.query_and_generate_response(query)
+#         return response
+
+# if __name__ == "__main__":
+#     # Example usage
+#     embedding_model_name = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L12'
+#     lm_model_id = "mistralai/Mistral-7B-Instruct-v0.2"
+#     data_folder = 'sample_embedding_folder'
+#     faiss_index_path = 'faiss_index_new_model3.index'
+
+#     doc_retrieval_gen = DocumentRetrievalAndGeneration(embedding_model_name, lm_model_id, data_folder, faiss_index_path)
+
+#     # Define Gradio interface function
+#     def launch_interface():
+#         css_code = """
+#             .gradio-container {
+#                 background-color: #daccdb;
+#             }
+#             /* Button styling for all buttons */
+#             button {
+#                 background-color: #927fc7; /* Default color for all other buttons */
+#                 color: black;
+#                 border: 1px solid black;
+#                 padding: 10px;
+#                 margin-right: 10px;
+#                 font-size: 16px; /* Increase font size */
+#                 font-weight: bold; /* Make text bold */
+#             }
+#         """
+#         EXAMPLES = ["Does the VIP modules & CSI2 module could work simultaneously? ",
+#                     "I'm using Code Composer Studio 5.4.0.00091 and enabled FPv4SPD16 floating point support for CortexM4 in TDA2. However, after building the project, the .asm file shows --float_support=vfplib instead of FPv4SPD16. Why is this happening?",
+#                     "Could you clarify the maximum number of cameras that can be connected simultaneously to the video input ports on the TDA2x SoC, considering it supports up to 10 multiplexed input ports and includes 3 dedicated video input modules?"]
+
+#         file_path = "ticketNames.txt"
+
+#         # Read the file content
+#         with open(file_path, "r") as file:
+#             content = file.read()
+#         ticket_names = json.loads(content)
+#         dropdown = gr.Dropdown(label="Sample queries", choices=ticket_names)
+
+#         # Define Gradio interface
+#         interface = gr.Interface(
+#             fn=doc_retrieval_gen.qa_infer_gradio,
+#             inputs=[gr.Textbox(label="QUERY", placeholder="Enter your query here")],
+#             allow_flagging='never',
+#             examples=EXAMPLES,
+#             cache_examples=False,
+#             outputs=[gr.Textbox(label="SOLUTION"), gr.Textbox(label="RELATED QUERIES")],
+#             css=css_code
+#         )
+
+#         # Launch Gradio interface
+#         interface.launch(debug=True)
+
+#     # Launch the interface
+#     launch_interface()
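
Note on this change: query_and_generate_response now resolves each FAISS hit through self.all_splits[idx], so the index stored at faiss_index_path must have been built from embeddings of those same splits, in the same order; otherwise the returned indices point at the wrong chunks. Below is a minimal sketch of a compatible index-building step, assuming the splitter settings, embedding model, and paths that appear in this file; the script itself (and the choice of IndexFlatL2) is illustrative and not part of this commit.

    # build_index.py -- illustrative sketch, not part of this commit.
    # Builds a FAISS index whose row i corresponds to all_splits[i], so that
    # gpu_index.search(...) indices line up with self.all_splits[idx] above.
    import faiss
    import numpy as np
    from langchain.document_loaders import TextLoader, DirectoryLoader
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from sentence_transformers import SentenceTransformer

    documents = DirectoryLoader('sample_embedding_folder', loader_cls=TextLoader).load()

    # Must mirror DocumentRetrievalAndGeneration.split_documents() exactly,
    # or search indices will point at the wrong chunks.
    splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=250)
    all_splits = splitter.split_documents(documents)

    embedder = SentenceTransformer('flax-sentence-embeddings/all_datasets_v3_MiniLM-L12')
    embeddings = np.asarray(
        embedder.encode([chunk.page_content for chunk in all_splits]),
        dtype='float32',
    )

    # IndexFlatL2 is an assumption; any index type readable by faiss.read_index
    # works, as long as vectors are added in split order.
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    faiss.write_index(index, 'faiss_index_new_model3.index')

The same alignment requirement means the index has to be rebuilt whenever chunk_size, chunk_overlap, the corpus, or the embedding model changes.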