syedmudassir16 commited on
Commit
f932d05
1 Parent(s): 75b7c0b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -42
app.py CHANGED
@@ -3,7 +3,6 @@ import multiprocessing
3
  import concurrent.futures
4
  from langchain.document_loaders import TextLoader, DirectoryLoader
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
- from langchain_community.vectorstores import FAISS
7
  from langchain_community.embeddings import HuggingFaceEmbeddings
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
9
  from datetime import datetime
@@ -14,15 +13,27 @@ from threading import Thread
14
  from transformers.agents import Tool, HfEngine, ReactJsonAgent
15
  from huggingface_hub import InferenceClient
16
  import logging
 
17
 
18
  logging.basicConfig(level=logging.INFO)
19
  logger = logging.getLogger(__name__)
20
 
 
 
 
 
 
 
 
21
  class DocumentRetrievalAndGeneration:
22
  def __init__(self, embedding_model_name, lm_model_id, data_folder):
23
  self.all_splits = self.load_documents(data_folder)
24
  self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
25
- self.vectordb = self.create_faiss_index()
 
 
 
 
26
  self.tokenizer, self.model = self.initialize_llm(lm_model_id)
27
  self.retriever_tool = self.create_retriever_tool()
28
  self.agent = self.create_agent()
@@ -37,6 +48,9 @@ class DocumentRetrievalAndGeneration:
37
  return all_splits
38
 
39
  def create_faiss_index(self):
 
 
 
40
  return FAISS.from_documents(self.all_splits, self.embeddings)
41
 
42
  def initialize_llm(self, model_id):
@@ -72,6 +86,8 @@ class DocumentRetrievalAndGeneration:
72
  self.vectordb = vectordb
73
 
74
  def forward(self, query: str) -> str:
 
 
75
  docs = self.vectordb.similarity_search(query, k=3)
76
  return "\nRetrieved documents:\n" + "".join(
77
  [f"===== Document {str(i)} =====\n" + doc.page_content for i, doc in enumerate(docs)]
@@ -126,53 +142,63 @@ Question:
126
  return response
127
 
128
  def save_index(self, path):
129
- self.vectordb.save_local(path)
 
 
 
130
 
131
  def load_index(self, path):
132
- self.vectordb = FAISS.load_local(path, self.embeddings)
 
 
 
133
 
134
  if __name__ == "__main__":
135
  embedding_model_name = 'thenlper/gte-small'
136
  lm_model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
137
  data_folder = 'sample_embedding_folder2'
138
 
139
- doc_retrieval_gen = DocumentRetrievalAndGeneration(embedding_model_name, lm_model_id, data_folder)
140
-
141
- # Save the index for future use
142
- doc_retrieval_gen.save_index("faiss_index")
143
-
144
- def launch_interface():
145
- css_code = """
146
- .gradio-container {
147
- background-color: #daccdb;
148
- }
149
- button {
150
- background-color: #927fc7;
151
- color: black;
152
- border: 1px solid black;
153
- padding: 10px;
154
- margin-right: 10px;
155
- font-size: 16px;
156
- font-weight: bold;
157
- }
158
- """
159
- EXAMPLES = [
160
- "On which devices can the VIP and CSI2 modules operate simultaneously?",
161
- "I'm using Code Composer Studio 5.4.0.00091 and enabled FPv4SPD16 floating point support for CortexM4 in TDA2. However, after building the project, the .asm file shows --float_support=vfplib instead of FPv4SPD16. Why is this happening?",
162
- "Could you clarify the maximum number of cameras that can be connected simultaneously to the video input ports on the TDA2x SoC, considering it supports up to 10 multiplexed input ports and includes 3 dedicated video input modules?"
163
- ]
164
-
165
- interface = gr.Interface(
166
- fn=doc_retrieval_gen.qa_infer_gradio,
167
- inputs=[gr.Textbox(label="QUERY", placeholder="Enter your query here")],
168
- allow_flagging='never',
169
- examples=EXAMPLES,
170
- cache_examples=False,
171
- outputs=[gr.Textbox(label="RESPONSE"), gr.Textbox(label="RELATED QUERIES")],
172
- css=css_code,
173
- title="TI E2E FORUM Multi-Agent RAG"
174
- )
175
 
176
- interface.launch(debug=True)
 
177
 
178
- launch_interface()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import concurrent.futures
4
  from langchain.document_loaders import TextLoader, DirectoryLoader
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
6
  from langchain_community.embeddings import HuggingFaceEmbeddings
7
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
8
  from datetime import datetime
 
13
  from transformers.agents import Tool, HfEngine, ReactJsonAgent
14
  from huggingface_hub import InferenceClient
15
  import logging
16
+ import torch
17
 
18
  logging.basicConfig(level=logging.INFO)
19
  logger = logging.getLogger(__name__)
20
 
21
+ try:
22
+ from langchain_community.vectorstores import FAISS
23
+ except ImportError:
24
+ logger.error("Failed to import FAISS. Make sure it's installed correctly.")
25
+ logger.info("You can try: pip install faiss-cpu --no-cache")
26
+ FAISS = None
27
+
28
  class DocumentRetrievalAndGeneration:
29
  def __init__(self, embedding_model_name, lm_model_id, data_folder):
30
  self.all_splits = self.load_documents(data_folder)
31
  self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
32
+ if FAISS is not None:
33
+ self.vectordb = self.create_faiss_index()
34
+ else:
35
+ logger.warning("FAISS is not available. Vector search functionality will be limited.")
36
+ self.vectordb = None
37
  self.tokenizer, self.model = self.initialize_llm(lm_model_id)
38
  self.retriever_tool = self.create_retriever_tool()
39
  self.agent = self.create_agent()
 
48
  return all_splits
49
 
50
  def create_faiss_index(self):
51
+ if FAISS is None:
52
+ logger.error("FAISS is not available. Cannot create index.")
53
+ return None
54
  return FAISS.from_documents(self.all_splits, self.embeddings)
55
 
56
  def initialize_llm(self, model_id):
 
86
  self.vectordb = vectordb
87
 
88
  def forward(self, query: str) -> str:
89
+ if self.vectordb is None:
90
+ return "Vector database is not available. Cannot perform retrieval."
91
  docs = self.vectordb.similarity_search(query, k=3)
92
  return "\nRetrieved documents:\n" + "".join(
93
  [f"===== Document {str(i)} =====\n" + doc.page_content for i, doc in enumerate(docs)]
 
142
  return response
143
 
144
  def save_index(self, path):
145
+ if self.vectordb is not None:
146
+ self.vectordb.save_local(path)
147
+ else:
148
+ logger.warning("Vector database is not available. Cannot save index.")
149
 
150
  def load_index(self, path):
151
+ if FAISS is not None:
152
+ self.vectordb = FAISS.load_local(path, self.embeddings)
153
+ else:
154
+ logger.warning("FAISS is not available. Cannot load index.")
155
 
156
  if __name__ == "__main__":
157
  embedding_model_name = 'thenlper/gte-small'
158
  lm_model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
159
  data_folder = 'sample_embedding_folder2'
160
 
161
+ try:
162
+ doc_retrieval_gen = DocumentRetrievalAndGeneration(embedding_model_name, lm_model_id, data_folder)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
+ # Save the index for future use
165
+ doc_retrieval_gen.save_index("faiss_index")
166
 
167
+ def launch_interface():
168
+ css_code = """
169
+ .gradio-container {
170
+ background-color: #daccdb;
171
+ }
172
+ button {
173
+ background-color: #927fc7;
174
+ color: black;
175
+ border: 1px solid black;
176
+ padding: 10px;
177
+ margin-right: 10px;
178
+ font-size: 16px;
179
+ font-weight: bold;
180
+ }
181
+ """
182
+ EXAMPLES = [
183
+ "On which devices can the VIP and CSI2 modules operate simultaneously?",
184
+ "I'm using Code Composer Studio 5.4.0.00091 and enabled FPv4SPD16 floating point support for CortexM4 in TDA2. However, after building the project, the .asm file shows --float_support=vfplib instead of FPv4SPD16. Why is this happening?",
185
+ "Could you clarify the maximum number of cameras that can be connected simultaneously to the video input ports on the TDA2x SoC, considering it supports up to 10 multiplexed input ports and includes 3 dedicated video input modules?"
186
+ ]
187
+
188
+ interface = gr.Interface(
189
+ fn=doc_retrieval_gen.qa_infer_gradio,
190
+ inputs=[gr.Textbox(label="QUERY", placeholder="Enter your query here")],
191
+ allow_flagging='never',
192
+ examples=EXAMPLES,
193
+ cache_examples=False,
194
+ outputs=[gr.Textbox(label="RESPONSE"), gr.Textbox(label="RELATED QUERIES")],
195
+ css=css_code,
196
+ title="TI E2E FORUM Multi-Agent RAG"
197
+ )
198
+
199
+ interface.launch(debug=True)
200
+
201
+ launch_interface()
202
+ except Exception as e:
203
+ logger.error(f"An error occurred: {str(e)}")
204
+ logger.info("Please check your environment setup and try again.")