syedmudassir16 commited on
Commit
e085441
1 Parent(s): 3cba93e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -28
app.py CHANGED
@@ -14,7 +14,6 @@ import json
14
  import gradio as gr
15
  import re
16
  from threading import Thread
17
- from transformers.agents import Tool, HfEngine, ReactJsonAgent
18
 
19
  class DocumentRetrievalAndGeneration:
20
  def __init__(self, embedding_model_name, lm_model_id, data_folder):
@@ -23,7 +22,6 @@ class DocumentRetrievalAndGeneration:
23
  self.gpu_index = self.create_faiss_index()
24
  self.tokenizer, self.model = self.initialize_llm(lm_model_id)
25
  self.retriever_tool = self.create_retriever_tool()
26
- self.agent = self.create_agent()
27
 
28
  def load_documents(self, folder_path):
29
  loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
@@ -89,22 +87,11 @@ class DocumentRetrievalAndGeneration:
89
  return "Text generation process encountered an error"
90
 
91
  def create_retriever_tool(self):
92
- class RetrieverTool(Tool):
93
- name = "retriever"
94
- description = "Retrieves documents from the knowledge base that are semantically similar to the input query."
95
- inputs = {
96
- "query": {
97
- "type": "text",
98
- "description": "The query to perform. Use affirmative form rather than a question.",
99
- }
100
- }
101
- output_type = "text"
102
-
103
- def __init__(self, parent, **kwargs):
104
- super().__init__(**kwargs)
105
  self.parent = parent
106
 
107
- def forward(self, query: str) -> str:
108
  similarityThreshold = 1
109
  query_embedding = self.parent.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
110
  distances, indices = self.parent.gpu_index.search(np.array([query_embedding]), k=3)
@@ -117,22 +104,23 @@ class DocumentRetrievalAndGeneration:
117
 
118
  return RetrieverTool(self)
119
 
120
- def create_agent(self):
121
- llm_engine = HfEngine("meta-llama/Meta-Llama-3.1-8B-Instruct")
122
- return ReactJsonAgent(tools=[self.retriever_tool], llm_engine=llm_engine, max_iterations=4, verbose=2)
123
-
124
  def run_agentic_rag(self, question: str) -> str:
125
- enhanced_question = f"""Using the information in your knowledge base, accessible with the 'retriever' tool,
126
- give a comprehensive answer to the question below.
 
 
 
 
 
127
  Respond only to the question asked, be concise and relevant.
128
- If you can't find information, try calling your retriever again with different arguments.
129
- Make sure to cover the question completely by calling the retriever tool several times with semantically different queries.
130
- Your queries should be in affirmative form, not questions.
131
 
132
- Question:
133
- {question}"""
134
 
135
- return self.agent.run(enhanced_question)
 
 
136
 
137
  def query_and_generate_response(self, query):
138
  # Standard RAG
 
14
  import gradio as gr
15
  import re
16
  from threading import Thread
 
17
 
18
  class DocumentRetrievalAndGeneration:
19
  def __init__(self, embedding_model_name, lm_model_id, data_folder):
 
22
  self.gpu_index = self.create_faiss_index()
23
  self.tokenizer, self.model = self.initialize_llm(lm_model_id)
24
  self.retriever_tool = self.create_retriever_tool()
 
25
 
26
  def load_documents(self, folder_path):
27
  loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
 
87
  return "Text generation process encountered an error"
88
 
89
  def create_retriever_tool(self):
90
+ class RetrieverTool:
91
+ def __init__(self, parent):
 
 
 
 
 
 
 
 
 
 
 
92
  self.parent = parent
93
 
94
+ def run(self, query: str) -> str:
95
  similarityThreshold = 1
96
  query_embedding = self.parent.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
97
  distances, indices = self.parent.gpu_index.search(np.array([query_embedding]), k=3)
 
104
 
105
  return RetrieverTool(self)
106
 
 
 
 
 
107
  def run_agentic_rag(self, question: str) -> str:
108
+ retriever_output = self.retriever_tool.run(question)
109
+
110
+ enhanced_prompt = f"""Using the following information retrieved from the knowledge base:
111
+
112
+ {retriever_output}
113
+
114
+ Give a comprehensive answer to the question below.
115
  Respond only to the question asked, be concise and relevant.
116
+ If you can't find information, say "No relevant information found."
 
 
117
 
118
+ Question: {question}
119
+ Answer:"""
120
 
121
+ input_ids = self.tokenizer.encode(enhanced_prompt, return_tensors="pt").to(self.model.device)
122
+
123
+ return self.generate_response_with_timeout(input_ids)
124
 
125
  def query_and_generate_response(self, query):
126
  # Standard RAG