syedmudassir16 committed
Commit
3cba93e
1 Parent(s): 7af7fc3

Update app.py

Files changed (1): app.py (+119 −110)
app.py CHANGED
@@ -1,9 +1,13 @@
 import os
 import multiprocessing
 import concurrent.futures
-from langchain_community.document_loaders import TextLoader, DirectoryLoader
+from langchain.document_loaders import TextLoader, DirectoryLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain.vectorstores import FAISS
+from sentence_transformers import SentenceTransformer
+import faiss
+import torch
+import numpy as np
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
 from datetime import datetime
 import json
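
Note on the import change: this hunk moves the loaders back from langchain_community to the legacy langchain paths. On LangChain 0.1+ the loaders and vector stores live in langchain_community (the paths the old code used), and the langchain.* forms resolve only on older releases or through deprecation shims, so this direction may warn or break on newer installs. A version-tolerant import, as a minimal sketch:

    # Sketch: prefer the newer langchain_community paths, fall back to the
    # legacy langchain ones this commit switches to.
    try:
        from langchain_community.document_loaders import TextLoader, DirectoryLoader
        from langchain_community.vectorstores import FAISS
    except ImportError:  # older LangChain releases
        from langchain.document_loaders import TextLoader, DirectoryLoader
        from langchain.vectorstores import FAISS
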
@@ -11,30 +15,11 @@ import gradio as gr
 import re
 from threading import Thread
 from transformers.agents import Tool, HfEngine, ReactJsonAgent
-from huggingface_hub import InferenceClient
-import logging
-import torch
-import numpy as np
-import faiss
-import warnings
-
-# Suppress specific warnings
-warnings.filterwarnings("ignore", category=FutureWarning)
-
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-try:
-    from langchain_community.vectorstores import FAISS
-except ImportError:
-    logger.error("Failed to import FAISS. Make sure it's installed correctly.")
-    logger.info("You can try: pip install faiss-cpu --no-cache")
-    FAISS = None
 
 class DocumentRetrievalAndGeneration:
     def __init__(self, embedding_model_name, lm_model_id, data_folder):
         self.all_splits = self.load_documents(data_folder)
-        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
+        self.embeddings = SentenceTransformer(embedding_model_name)
         self.gpu_index = self.create_faiss_index()
         self.tokenizer, self.model = self.initialize_llm(lm_model_id)
         self.retriever_tool = self.create_retriever_tool()
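
The dropped HuggingFaceEmbeddings wrapper is itself backed by sentence-transformers, so instantiating SentenceTransformer directly mostly removes a layer of indirection: embed_documents/embed_query calls become encode calls in the hunks below. One difference worth remembering is that encode returns a numpy array by default, so the convert_to_tensor=True ... .cpu().numpy() round-trip used later is not strictly needed. A quick shape check, with a hypothetical model name standing in for embedding_model_name (its actual value is set in the __main__ block, outside this diff):

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # hypothetical stand-in
    vecs = model.encode(["first chunk", "second chunk"])  # numpy array of shape (2, dim)
    print(vecs.shape)
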
@@ -45,15 +30,17 @@ class DocumentRetrievalAndGeneration:
         documents = loader.load()
         text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=250)
         all_splits = text_splitter.split_documents(documents)
-        logger.info(f'Loaded {len(documents)} documents')
-        logger.info(f"Split into {len(all_splits)} chunks")
+        print('Length of documents:', len(documents))
+        print("LEN of all_splits", len(all_splits))
+        for i in range(3):
+            print(all_splits[i].page_content)
         return all_splits
 
     def create_faiss_index(self):
         all_texts = [split.page_content for split in self.all_splits]
-        embeddings = self.embeddings.embed_documents(all_texts)
-        index = faiss.IndexFlatL2(len(embeddings[0]))
-        index.add(np.array(embeddings))
+        embeddings = self.embeddings.encode(all_texts, convert_to_tensor=True).cpu().numpy()
+        index = faiss.IndexFlatL2(embeddings.shape[1])
+        index.add(embeddings)
         gpu_resource = faiss.StandardGpuResources()
         gpu_index = faiss.index_cpu_to_gpu(gpu_resource, 0, index)
         return gpu_index
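
The index here is an exact (brute-force) L2 index moved onto GPU 0. faiss.StandardGpuResources and faiss.index_cpu_to_gpu exist only in faiss-gpu builds, so a CPU-only faiss install fails at this point with an AttributeError. A guarded variant, as a sketch:

    import faiss
    import numpy as np

    def build_index(embeddings: np.ndarray):
        # embeddings: (n_chunks, dim); FAISS expects float32
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings.astype(np.float32))
        if hasattr(faiss, "StandardGpuResources"):  # true only for faiss-gpu builds
            index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, index)
        return index
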
@@ -74,6 +61,33 @@ class DocumentRetrievalAndGeneration:
         )
         return tokenizer, model
 
+    def generate_response_with_timeout(self, input_ids, max_new_tokens=1000):
+        try:
+            streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+            generate_kwargs = dict(
+                input_ids=input_ids,
+                max_new_tokens=max_new_tokens,
+                do_sample=True,
+                top_p=1.0,
+                top_k=20,
+                temperature=0.8,
+                repetition_penalty=1.2,
+                eos_token_id=[128001, 128008, 128009],
+                streamer=streamer,
+            )
+
+            thread = Thread(target=self.model.generate, kwargs=generate_kwargs)
+            thread.start()
+
+            generated_text = ""
+            for new_text in streamer:
+                generated_text += new_text
+
+            return generated_text
+        except Exception as e:
+            print(f"Error in generate_response_with_timeout: {str(e)}")
+            return "Text generation process encountered an error"
+
     def create_retriever_tool(self):
         class RetrieverTool(Tool):
             name = "retriever"
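
This method is relocated from further down the file (its deletion appears in a later hunk), with the logger calls swapped for print. One caveat: the timeout=60.0 passed to TextIteratorStreamer applies to each read from its internal queue, so if generation stalls, iteration raises queue.Empty, the except branch returns the error string, and the background generate thread is left running unjoined. A stricter variant that always joins the worker, as a sketch assuming the same self.tokenizer and self.model attributes:

    def generate_and_join(self, input_ids, max_new_tokens=1000):
        streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0,
                                        skip_prompt=True, skip_special_tokens=True)
        thread = Thread(target=self.model.generate,
                        kwargs=dict(input_ids=input_ids,
                                    max_new_tokens=max_new_tokens,
                                    streamer=streamer))
        thread.start()
        try:
            return "".join(streamer)  # raises queue.Empty if a chunk takes > 60 s
        finally:
            thread.join()  # wait for the worker before returning or propagating
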
@@ -92,15 +106,13 @@ class DocumentRetrievalAndGeneration:
 
             def forward(self, query: str) -> str:
                 similarityThreshold = 1
-                query_embedding = self.parent.embeddings.embed_query(query)
+                query_embedding = self.parent.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
                 distances, indices = self.parent.gpu_index.search(np.array([query_embedding]), k=3)
                 content = ""
-                filtered_results = []
                 for idx, distance in zip(indices[0], distances[0]):
                     if distance <= similarityThreshold:
-                        filtered_results.append(idx)
-                        content += "-" * 50 + "\n"
-                        content += self.parent.all_splits[idx].page_content + "\n"
+                        content += "-" * 50 + "\n"
+                        content += self.parent.all_splits[idx].page_content + "\n"
                 return content
 
         return RetrieverTool(self)
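
Since the index is IndexFlatL2, search returns squared L2 distances, where smaller means more similar; with similarityThreshold = 1 only very close chunks survive the filter, and the right cutoff depends on the embedding scale. If cosine-style scoring is preferred, one common alternative (a sketch, not what this commit does) is to normalize the vectors and use an inner-product index, which flips the filter to score >= threshold:

    import faiss
    import numpy as np

    def build_cosine_index(embeddings: np.ndarray):
        vecs = embeddings.astype(np.float32).copy()
        faiss.normalize_L2(vecs)                  # in-place row normalization
        index = faiss.IndexFlatIP(vecs.shape[1])  # inner product == cosine on unit vectors
        index.add(vecs)
        return index
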
@@ -109,33 +121,6 @@ class DocumentRetrievalAndGeneration:
         llm_engine = HfEngine("meta-llama/Meta-Llama-3.1-8B-Instruct")
         return ReactJsonAgent(tools=[self.retriever_tool], llm_engine=llm_engine, max_iterations=4, verbose=2)
 
-    def generate_response_with_timeout(self, input_ids, max_new_tokens=1000):
-        try:
-            streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-            generate_kwargs = dict(
-                input_ids=input_ids,
-                max_new_tokens=max_new_tokens,
-                do_sample=True,
-                top_p=1.0,
-                top_k=20,
-                temperature=0.8,
-                repetition_penalty=1.2,
-                eos_token_id=[128001, 128008, 128009],
-                streamer=streamer,
-            )
-
-            thread = Thread(target=self.model.generate, kwargs=generate_kwargs)
-            thread.start()
-
-            generated_text = ""
-            for new_text in streamer:
-                generated_text += new_text
-
-            return generated_text
-        except Exception as e:
-            logger.error(f"Error in generate_response_with_timeout: {str(e)}")
-            return "Text generation process encountered an error"
-
     def run_agentic_rag(self, question: str) -> str:
         enhanced_question = f"""Using the information in your knowledge base, accessible with the 'retriever' tool,
 give a comprehensive answer to the question below.
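
For context, HfEngine and ReactJsonAgent come from the transformers.agents module in the transformers releases this app targets; the agent loops up to max_iterations=4 times, emitting JSON tool calls against the retriever tool defined above. A standalone usage sketch, assuming a Hugging Face token with access to the model is configured in the environment, where retriever_tool is the RetrieverTool instance built by create_retriever_tool:

    from transformers.agents import HfEngine, ReactJsonAgent

    llm_engine = HfEngine("meta-llama/Meta-Llama-3.1-8B-Instruct")
    agent = ReactJsonAgent(tools=[retriever_tool], llm_engine=llm_engine,
                           max_iterations=4, verbose=2)
    print(agent.run("On which devices can the VIP and CSI2 modules operate simultaneously?"))
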
@@ -149,33 +134,64 @@ Question:
 
         return self.agent.run(enhanced_question)
 
-    def run_standard_rag(self, question: str) -> str:
-        context = self.retriever_tool(query=question)
+    def query_and_generate_response(self, query):
+        # Standard RAG
+        similarityThreshold = 1
+        query_embedding = self.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
+        distances, indices = self.gpu_index.search(np.array([query_embedding]), k=3)
+        print("Distance", distances, "indices", indices)
+        content = ""
+        filtered_results = []
+        for idx, distance in zip(indices[0], distances[0]):
+            if distance <= similarityThreshold:
+                filtered_results.append(idx)
+                for i in filtered_results:
+                    print(self.all_splits[i].page_content)
+                content += "-" * 50 + "\n"
+                content += self.all_splits[idx].page_content + "\n"
+                print("CHUNK", idx)
+                print("Distance:", distance)
+                print("indices:", indices)
+                print(self.all_splits[idx].page_content)
+                print("############################")
 
         conversation = [
             {"role": "system", "content": "You are a knowledgeable assistant with access to a comprehensive database."},
             {"role": "user", "content": f"""
 I need you to answer my question and provide related information in a specific format.
-I have provided five relatable json files {context}, choose the most suitable chunks for answering the query.
+I have provided five relatable json files {content}, choose the most suitable chunks for answering the query.
 RETURN ONLY SOLUTION without additional comments, sign-offs, retrived chunks, refrence to any Ticket or extra phrases. Be direct and to the point.
 IF THERE IS NO ANSWER RELATABLE IN RETRIEVED CHUNKS, RETURN "NO SOLUTION AVAILABLE".
 DO NOT GIVE REFRENCE TO ANY CHUNKS OR TICKETS,BE ON POINT.
 
 Here's my question:
-Query: {question}
+Query: {query}
 Solution==>
 """}
         ]
         input_ids = self.tokenizer.apply_chat_template(conversation, return_tensors="pt").to(self.model.device)
 
-        return self.generate_response_with_timeout(input_ids)
-
-    def query_and_generate_response(self, query):
-        agentic_answer = self.run_agentic_rag(query)
-        standard_answer = self.run_standard_rag(query)
-
-        combined_answer = f"Agentic RAG Answer:\n{agentic_answer}\n\nStandard RAG Answer:\n{standard_answer}"
-        return combined_answer, ""  # Return empty string for 'content' as it's not used in this implementation
+        start_time = datetime.now()
+        standard_response = self.generate_response_with_timeout(input_ids)
+        elapsed_time = datetime.now() - start_time
+
+        print("Generated standard response:", standard_response)
+        print("Time elapsed:", elapsed_time)
+        print("Device in use:", self.model.device)
+
+        standard_solution_text = standard_response.strip()
+        if "Solution:" in standard_solution_text:
+            standard_solution_text = standard_solution_text.split("Solution:", 1)[1].strip()
+
+        # Post-processing to remove "assistant" prefix
+        standard_solution_text = re.sub(r'^assistant\s*', '', standard_solution_text, flags=re.IGNORECASE)
+        standard_solution_text = standard_solution_text.strip()
+
+        # Agentic RAG
+        agentic_solution_text = self.run_agentic_rag(query)
+
+        combined_solution = f"Standard RAG Solution:\n{standard_solution_text}\n\nAgentic RAG Solution:\n{agentic_solution_text}"
+        return combined_solution, content
 
     def qa_infer_gradio(self, query):
         response = self.query_and_generate_response(query)
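
Two details in the merged method are worth flagging. First, it now returns a (combined_solution, content) tuple, so the raw retrieved chunks flow into the interface's second output box. Second, the debug block re-prints every previously matched chunk on each new hit, because the for i in filtered_results loop sits inside the if, and the prompt still says "five relatable json files" even though k=3 chunks are retrieved. A tighter equivalent of the retrieval step, as a sketch:

    def retrieve_context(self, query, k=3, threshold=1.0):
        query_vec = self.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
        distances, indices = self.gpu_index.search(np.array([query_vec]), k=k)
        matched = [self.all_splits[idx].page_content
                   for idx, dist in zip(indices[0], distances[0])
                   if dist <= threshold]
        sep = "-" * 50 + "\n"
        return "".join(sep + chunk + "\n" for chunk in matched)
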
@@ -186,47 +202,40 @@ if __name__ == "__main__":
 
     lm_model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
     data_folder = 'sample_embedding_folder2'
 
-    # Set your HuggingFace token here
-    os.environ["HUGGINGFACE_TOKEN"] = "your_huggingface_token_here"
-
-    try:
-        doc_retrieval_gen = DocumentRetrievalAndGeneration(embedding_model_name, lm_model_id, data_folder)
-
-        def launch_interface():
-            css_code = """
-            .gradio-container {
-                background-color: #daccdb;
-            }
-            button {
-                background-color: #927fc7;
-                color: black;
-                border: 1px solid black;
-                padding: 10px;
-                margin-right: 10px;
-                font-size: 16px;
-                font-weight: bold;
-            }
-            """
-            EXAMPLES = [
-                "On which devices can the VIP and CSI2 modules operate simultaneously?",
-                "I'm using Code Composer Studio 5.4.0.00091 and enabled FPv4SPD16 floating point support for CortexM4 in TDA2. However, after building the project, the .asm file shows --float_support=vfplib instead of FPv4SPD16. Why is this happening?",
-                "Could you clarify the maximum number of cameras that can be connected simultaneously to the video input ports on the TDA2x SoC, considering it supports up to 10 multiplexed input ports and includes 3 dedicated video input modules?"
-            ]
-
-            interface = gr.Interface(
-                fn=doc_retrieval_gen.qa_infer_gradio,
-                inputs=[gr.Textbox(label="QUERY", placeholder="Enter your query here")],
-                allow_flagging='never',
-                examples=EXAMPLES,
-                cache_examples=False,
-                outputs=[gr.Textbox(label="RESPONSE"), gr.Textbox(label="RELATED QUERIES")],
-                css=css_code,
-                title="TI E2E FORUM Multi-Agent RAG"
-            )
-
-            interface.launch(debug=True)
-
-        launch_interface()
-    except Exception as e:
-        logger.error(f"An error occurred: {str(e)}")
-        logger.info("Please check your environment setup and try again.")
+    doc_retrieval_gen = DocumentRetrievalAndGeneration(embedding_model_name, lm_model_id, data_folder)
+
+    def launch_interface():
+        css_code = """
+        .gradio-container {
+            background-color: #daccdb;
+        }
+        button {
+            background-color: #927fc7;
+            color: black;
+            border: 1px solid black;
+            padding: 10px;
+            margin-right: 10px;
+            font-size: 16px;
+            font-weight: bold;
+        }
+        """
+        EXAMPLES = [
+            "On which devices can the VIP and CSI2 modules operate simultaneously?",
+            "I'm using Code Composer Studio 5.4.0.00091 and enabled FPv4SPD16 floating point support for CortexM4 in TDA2. However, after building the project, the .asm file shows --float_support=vfplib instead of FPv4SPD16. Why is this happening?",
+            "Could you clarify the maximum number of cameras that can be connected simultaneously to the video input ports on the TDA2x SoC, considering it supports up to 10 multiplexed input ports and includes 3 dedicated video input modules?"
+        ]
+
+        interface = gr.Interface(
+            fn=doc_retrieval_gen.qa_infer_gradio,
+            inputs=[gr.Textbox(label="QUERY", placeholder="Enter your query here")],
+            allow_flagging='never',
+            examples=EXAMPLES,
+            cache_examples=False,
+            outputs=[gr.Textbox(label="RESPONSE"), gr.Textbox(label="RELATED QUERIES")],
+            css=css_code,
+            title="TI E2E FORUM"
+        )
+
+        interface.launch(debug=True)
+
+    launch_interface()
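
One consequence of dropping the try/except wrapper: any failure while loading documents, building the index, or fetching the model now surfaces as an uncaught traceback rather than a logged message. Note also that the second output textbox, labeled RELATED QUERIES, actually receives the retrieved chunk text returned by query_and_generate_response. If a startup guard is still wanted without the removed logger, a minimal sketch:

    if __name__ == "__main__":
        try:
            doc_retrieval_gen = DocumentRetrievalAndGeneration(
                embedding_model_name, lm_model_id, data_folder)
            launch_interface()
        except Exception as e:
            print(f"Startup failed: {e}")  # surface the cause, then re-raise
            raise
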
 