syedmudassir16 committed on
Commit
495b986
1 Parent(s): f932d05

Update app.py

Files changed (1)
  1. app.py +74 -50
app.py CHANGED
@@ -14,6 +14,8 @@ from transformers.agents import Tool, HfEngine, ReactJsonAgent
 from huggingface_hub import InferenceClient
 import logging
 import torch
+import numpy as np
+import faiss
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -29,11 +31,7 @@ class DocumentRetrievalAndGeneration:
     def __init__(self, embedding_model_name, lm_model_id, data_folder):
         self.all_splits = self.load_documents(data_folder)
         self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
-        if FAISS is not None:
-            self.vectordb = self.create_faiss_index()
-        else:
-            logger.warning("FAISS is not available. Vector search functionality will be limited.")
-            self.vectordb = None
+        self.gpu_index = self.create_faiss_index()
         self.tokenizer, self.model = self.initialize_llm(lm_model_id)
         self.retriever_tool = self.create_retriever_tool()
         self.agent = self.create_agent()
@@ -41,17 +39,20 @@ class DocumentRetrievalAndGeneration:
     def load_documents(self, folder_path):
         loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
         documents = loader.load()
-        text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=20)
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=250)
         all_splits = text_splitter.split_documents(documents)
         logger.info(f'Loaded {len(documents)} documents')
         logger.info(f"Split into {len(all_splits)} chunks")
         return all_splits
 
     def create_faiss_index(self):
-        if FAISS is None:
-            logger.error("FAISS is not available. Cannot create index.")
-            return None
-        return FAISS.from_documents(self.all_splits, self.embeddings)
+        all_texts = [split.page_content for split in self.all_splits]
+        embeddings = self.embeddings.embed_documents(all_texts)
+        index = faiss.IndexFlatL2(len(embeddings[0]))
+        index.add(np.array(embeddings))
+        gpu_resource = faiss.StandardGpuResources()
+        gpu_index = faiss.index_cpu_to_gpu(gpu_resource, 0, index)
+        return gpu_index
 
     def initialize_llm(self, model_id):
         quantization_config = BitsAndBytesConfig(
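Note on this hunk: the LangChain FAISS wrapper is replaced with a hand-built flat L2 index copied onto GPU 0. A minimal self-contained sketch of the same pattern, assuming a faiss-gpu build and one CUDA device (random vectors stand in for real chunk embeddings); FAISS expects float32, so the float64 array that np.array(embeddings) produces in the commit may need an explicit cast on some faiss versions:

# Sketch of the GPU flat-L2 index pattern above; assumes the faiss-gpu
# package and at least one CUDA device. Random data stands in for real
# chunk embeddings.
import numpy as np
import faiss

dim = 384                                                  # e.g. a MiniLM embedding width
chunk_vecs = np.random.rand(1000, dim).astype("float32")   # FAISS expects float32

index = faiss.IndexFlatL2(dim)                             # exact L2 search, no training step
index.add(chunk_vecs)                                      # row i answers for chunk i

res = faiss.StandardGpuResources()                         # GPU scratch allocator
gpu_index = faiss.index_cpu_to_gpu(res, 0, index)          # copy the index onto device 0

query = np.random.rand(1, dim).astype("float32")
distances, indices = gpu_index.search(query, k=3)          # smaller L2 distance = closer match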
@@ -81,24 +82,56 @@ class DocumentRetrievalAndGeneration:
             }
             output_type = "text"
 
-            def __init__(self, vectordb, **kwargs):
+            def __init__(self, parent, **kwargs):
                 super().__init__(**kwargs)
-                self.vectordb = vectordb
+                self.parent = parent
 
             def forward(self, query: str) -> str:
-                if self.vectordb is None:
-                    return "Vector database is not available. Cannot perform retrieval."
-                docs = self.vectordb.similarity_search(query, k=3)
-                return "\nRetrieved documents:\n" + "".join(
-                    [f"===== Document {str(i)} =====\n" + doc.page_content for i, doc in enumerate(docs)]
-                )
-
-        return RetrieverTool(self.vectordb)
+                similarityThreshold = 1
+                query_embedding = self.parent.embeddings.embed_query(query)
+                distances, indices = self.parent.gpu_index.search(np.array([query_embedding]), k=3)
+                content = ""
+                filtered_results = []
+                for idx, distance in zip(indices[0], distances[0]):
+                    if distance <= similarityThreshold:
+                        filtered_results.append(idx)
+                        content += "-" * 50 + "\n"
+                        content += self.parent.all_splits[idx].page_content + "\n"
+                return content
+
+        return RetrieverTool(self)
 
     def create_agent(self):
         llm_engine = HfEngine("meta-llama/Meta-Llama-3.1-8B-Instruct")
         return ReactJsonAgent(tools=[self.retriever_tool], llm_engine=llm_engine, max_iterations=4, verbose=2)
 
+    def generate_response_with_timeout(self, input_ids, max_new_tokens=1000):
+        try:
+            streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+            generate_kwargs = dict(
+                input_ids=input_ids,
+                max_new_tokens=max_new_tokens,
+                do_sample=True,
+                top_p=1.0,
+                top_k=20,
+                temperature=0.8,
+                repetition_penalty=1.2,
+                eos_token_id=[128001, 128008, 128009],
+                streamer=streamer,
+            )
+
+            thread = Thread(target=self.model.generate, kwargs=generate_kwargs)
+            thread.start()
+
+            generated_text = ""
+            for new_text in streamer:
+                generated_text += new_text
+
+            return generated_text
+        except Exception as e:
+            logger.error(f"Error in generate_response_with_timeout: {str(e)}")
+            return "Text generation process encountered an error"
+
     def run_agentic_rag(self, question: str) -> str:
         enhanced_question = f"""Using the information in your knowledge base, accessible with the 'retriever' tool,
 give a comprehensive answer to the question below.
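Note on generate_response_with_timeout: the 60-second timeout bounds the wait for each streamed chunk (a stall makes the streamer raise queue.Empty, caught by the broad except), not total generation time, so the method name is slightly loose. A stripped-down sketch of the same thread-plus-streamer pattern; "gpt2" is only a small placeholder model so the snippet runs standalone:

# Minimal version of the streamer pattern used in the commit.
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("Hello, world.", return_tensors="pt").input_ids

# timeout=60.0 applies to EACH chunk pulled from the queue, not the whole run.
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

# model.generate blocks, so it runs in a worker thread while the main thread
# drains the streamer's queue chunk by chunk.
thread = Thread(target=model.generate, kwargs=dict(input_ids=input_ids, max_new_tokens=40, streamer=streamer))
thread.start()

generated = "".join(chunk for chunk in streamer)  # iterates until generation finishes
thread.join()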
@@ -115,20 +148,23 @@ Question:
     def run_standard_rag(self, question: str) -> str:
         context = self.retriever_tool(query=question)
 
-        prompt = f"""Given the question and supporting documents below, give a comprehensive answer to the question.
-Respond only to the question asked, be concise and relevant.
-Provide the number of the source document when relevant.
-
-Question:
-{question}
-
-{context}
-"""
-        messages = [{"role": "user", "content": prompt}]
-
-        reader_llm = InferenceClient("meta-llama/Meta-Llama-3.1-8B-Instruct")
-
-        return reader_llm.chat_completion(messages).choices[0].message.content
+        conversation = [
+            {"role": "system", "content": "You are a knowledgeable assistant with access to a comprehensive database."},
+            {"role": "user", "content": f"""
+            I need you to answer my question and provide related information in a specific format.
+            I have provided five relatable JSON files {context}; choose the most suitable chunks for answering the query.
+            RETURN ONLY THE SOLUTION, without additional comments, sign-offs, retrieved chunks, references to any ticket, or extra phrases. Be direct and to the point.
+            IF THERE IS NO ANSWER RELATABLE IN THE RETRIEVED CHUNKS, RETURN "NO SOLUTION AVAILABLE".
+            DO NOT GIVE REFERENCES TO ANY CHUNKS OR TICKETS; BE ON POINT.
+
+            Here's my question:
+            Query: {question}
+            Solution==>
+            """}
+        ]
+        input_ids = self.tokenizer.apply_chat_template(conversation, return_tensors="pt").to(self.model.device)
+
+        return self.generate_response_with_timeout(input_ids)
 
     def query_and_generate_response(self, query):
         agentic_answer = self.run_agentic_rag(query)
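Note on this hunk: run_standard_rag now formats a system/user conversation with the tokenizer's chat template and generates locally instead of calling InferenceClient. One detail worth checking: the commit does not pass add_generation_prompt=True, which for chat models appends the assistant header so the model starts a fresh answer rather than continuing the user turn. A sketch of that step, using TinyLlama's tokenizer as a lightweight placeholder that ships a chat template (the app itself uses Llama-3.1-8B-Instruct):

# Sketch of the chat-template step only; model names are placeholders.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

conversation = [
    {"role": "system", "content": "You are a knowledgeable assistant."},
    {"role": "user", "content": "Query: how do I reset the device?"},
]
input_ids = tokenizer.apply_chat_template(
    conversation,
    add_generation_prompt=True,   # append the assistant header so the model answers next
    return_tensors="pt",
)
print(input_ids.shape)            # (1, sequence_length), ready for model.generate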
@@ -141,29 +177,17 @@ Question:
         response = self.query_and_generate_response(query)
         return response
 
-    def save_index(self, path):
-        if self.vectordb is not None:
-            self.vectordb.save_local(path)
-        else:
-            logger.warning("Vector database is not available. Cannot save index.")
-
-    def load_index(self, path):
-        if FAISS is not None:
-            self.vectordb = FAISS.load_local(path, self.embeddings)
-        else:
-            logger.warning("FAISS is not available. Cannot load index.")
-
 if __name__ == "__main__":
-    embedding_model_name = 'thenlper/gte-small'
+    embedding_model_name = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L12'
     lm_model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
     data_folder = 'sample_embedding_folder2'
 
+    # Set your HuggingFace token here
+    os.environ["HUGGINGFACE_TOKEN"] = "your_huggingface_token_here"
+
     try:
         doc_retrieval_gen = DocumentRetrievalAndGeneration(embedding_model_name, lm_model_id, data_folder)
 
-        # Save the index for future use
-        doc_retrieval_gen.save_index("faiss_index")
-
         def launch_interface():
             css_code = """
             .gradio-container {
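A caveat on the token setup in the last hunk: as far as the hub client is concerned, huggingface_hub reads HF_TOKEN (and the older HUGGING_FACE_HUB_TOKEN) from the environment, not a custom HUGGINGFACE_TOKEN variable, and a literal token committed to app.py would ship with the repo. A hedged alternative:

# Alternative sketch: let huggingface_hub pick up HF_TOKEN on its own, or
# log in explicitly. Assumes HF_TOKEN is supplied externally, e.g. as a
# Space secret; never commit a literal token.
import os
from huggingface_hub import login

login(token=os.environ["HF_TOKEN"])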
 