syedmudassir16 committed
Commit b2ba33f
1 Parent(s): d734b57

Update app.py

Files changed (1):
  1. app.py +97 -77

app.py CHANGED
@@ -1,4 +1,6 @@
 import os
+import multiprocessing
+import concurrent.futures
 from langchain.document_loaders import TextLoader, DirectoryLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import FAISS
@@ -8,61 +10,44 @@ import torch
 import numpy as np
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
 from datetime import datetime
+import json
 import gradio as gr
 import re
 from threading import Thread
+from transformers.agents import Tool, HfEngine, ReactJsonAgent
+from huggingface_hub import InferenceClient
+import logging
 
-class MultiDocumentAgentSystem:
-    def __init__(self, documents_dict, model, tokenizer, embeddings):
-        self.model = model
-        self.tokenizer = tokenizer
-        self.embeddings = embeddings
-        self.document_vectors = self.create_document_vectors(documents_dict)
-
-    def create_document_vectors(self, documents_dict):
-        document_vectors = {}
-        for doc_name, content in documents_dict.items():
-            vectors = self.embeddings.encode(content, convert_to_tensor=True)
-            document_vectors[doc_name] = vectors
-        return document_vectors
-
-    def query(self, user_input):
-        query_vector = self.embeddings.encode(user_input, convert_to_tensor=True)
-
-        # Find the most similar document
-        most_similar_doc = max(self.document_vectors.items(),
-                               key=lambda x: torch.cosine_similarity(query_vector, x[1], dim=0))
-
-        # Generate response using the most similar document as context
-        response = self.generate_response(user_input, most_similar_doc[0], most_similar_doc[1])
-        return response
-
-    def generate_response(self, query, doc_name, doc_vector):
-        prompt = f"Based on the document '{doc_name}', answer the following question: {query}"
-        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device)
-
-        with torch.no_grad():
-            output = self.model.generate(input_ids, max_length=150, num_return_sequences=1)
-
-        response = self.tokenizer.decode(output[0], skip_special_tokens=True)
-        return response
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
 class DocumentRetrievalAndGeneration:
     def __init__(self, embedding_model_name, lm_model_id, data_folder):
-        self.documents_dict = self.load_documents(data_folder)
+        self.all_splits = self.load_documents(data_folder)
         self.embeddings = SentenceTransformer(embedding_model_name)
+        self.vectordb = self.create_faiss_index()
         self.tokenizer, self.model = self.initialize_llm(lm_model_id)
-        self.multi_doc_system = MultiDocumentAgentSystem(self.documents_dict, self.model, self.tokenizer, self.embeddings)
+        self.retriever_tool = self.create_retriever_tool()
+        self.agent = self.create_agent()
 
     def load_documents(self, folder_path):
-        documents_dict = {}
-        for file_name in os.listdir(folder_path):
-            if file_name.endswith('.txt'):
-                file_path = os.path.join(folder_path, file_name)
-                with open(file_path, 'r', encoding='utf-8') as file:
-                    content = file.read()
-                documents_dict[file_name[:-4]] = content
-        return documents_dict
+        loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
+        documents = loader.load()
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=20)
+        all_splits = text_splitter.split_documents(documents)
+        logger.info(f'Loaded {len(documents)} documents')
+        logger.info(f"Split into {len(all_splits)} chunks")
+        return all_splits
+
+    def create_faiss_index(self):
+        all_texts = [split.page_content for split in self.all_splits]
+        embeddings = self.embeddings.encode(all_texts, convert_to_tensor=True).cpu().numpy()
+        vectordb = FAISS.from_embeddings(
+            embeddings,
+            self.embeddings,
+            metadatas=[{"source": f"doc_{i}"} for i in range(len(all_texts))]
+        )
+        return vectordb
 
     def initialize_llm(self, model_id):
         quantization_config = BitsAndBytesConfig(
@@ -80,44 +65,79 @@ class DocumentRetrievalAndGeneration:
         )
         return tokenizer, model
 
-    def generate_response_with_timeout(self, input_ids, max_new_tokens=1000):
-        try:
-            streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-            generate_kwargs = dict(
-                input_ids=input_ids,
-                max_new_tokens=max_new_tokens,
-                do_sample=True,
-                top_p=1.0,
-                top_k=20,
-                temperature=0.8,
-                repetition_penalty=1.2,
-                eos_token_id=self.tokenizer.eos_token_id,
-                streamer=streamer,
-            )
-
-            thread = Thread(target=self.model.generate, kwargs=generate_kwargs)
-            thread.start()
-
-            generated_text = ""
-            for new_text in streamer:
-                generated_text += new_text
-
-            return generated_text
-        except Exception as e:
-            print(f"Error in generate_response_with_timeout: {str(e)}")
-            return "Text generation process encountered an error"
+    def create_retriever_tool(self):
+        class RetrieverTool(Tool):
+            name = "retriever"
+            description = "Retrieves documents from the knowledge base that are semantically similar to the input query."
+            inputs = {
+                "query": {
+                    "type": "text",
+                    "description": "The query to perform. Use affirmative form rather than a question.",
+                }
+            }
+            output_type = "text"
+
+            def __init__(self, vectordb, **kwargs):
+                super().__init__(**kwargs)
+                self.vectordb = vectordb
+
+            def forward(self, query: str) -> str:
+                docs = self.vectordb.similarity_search(query, k=3)
+                return "\nRetrieved documents:\n" + "".join(
+                    [f"===== Document {str(i)} =====\n" + doc.page_content for i, doc in enumerate(docs)]
+                )
+
+        return RetrieverTool(self.vectordb)
+
+    def create_agent(self):
+        llm_engine = HfEngine("meta-llama/Meta-Llama-3.1-8B-Instruct")
+        return ReactJsonAgent(tools=[self.retriever_tool], llm_engine=llm_engine, max_iterations=4, verbose=2)
+
+    def run_agentic_rag(self, question: str) -> str:
+        enhanced_question = f"""Using the information in your knowledge base, accessible with the 'retriever' tool,
+give a comprehensive answer to the question below.
+Respond only to the question asked, be concise and relevant.
+If you can't find information, try calling your retriever again with different arguments.
+Make sure to cover the question completely by calling the retriever tool several times with semantically different queries.
+Your queries should be in affirmative form, not questions.
+
+Question:
+{question}"""
+
+        return self.agent.run(enhanced_question)
+
+    def run_standard_rag(self, question: str) -> str:
+        context = self.retriever_tool(query=question)
+
+        prompt = f"""Given the question and supporting documents below, give a comprehensive answer to the question.
+Respond only to the question asked, be concise and relevant.
+Provide the number of the source document when relevant.
+
+Question:
+{question}
+
+{context}
+"""
+        messages = [{"role": "user", "content": prompt}]
+
+        reader_llm = InferenceClient("meta-llama/Meta-Llama-3.1-8B-Instruct")
+
+        return reader_llm.chat_completion(messages).choices[0].message.content
 
     def query_and_generate_response(self, query):
-        response = self.multi_doc_system.query(query)
-        return str(response), ""
+        agentic_answer = self.run_agentic_rag(query)
+        standard_answer = self.run_standard_rag(query)
+
+        combined_answer = f"Agentic RAG Answer:\n{agentic_answer}\n\nStandard RAG Answer:\n{standard_answer}"
+        return combined_answer, ""  # Empty string fills the unused related-queries output
 
     def qa_infer_gradio(self, query):
-        response, related_queries = self.query_and_generate_response(query)
-        return response, related_queries
+        response = self.query_and_generate_response(query)
+        return response
 
 if __name__ == "__main__":
-    embedding_model_name = 'sentence-transformers/all-MiniLM-L6-v2'
-    lm_model_id = "facebook/opt-350m"  # You can change this to a different open-source model
+    embedding_model_name = 'thenlper/gte-small'
+    lm_model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
     data_folder = 'sample_embedding_folder2'
 
     doc_retrieval_gen = DocumentRetrievalAndGeneration(embedding_model_name, lm_model_id, data_folder)
@@ -151,7 +171,7 @@ if __name__ == "__main__":
         cache_examples=False,
         outputs=[gr.Textbox(label="RESPONSE"), gr.Textbox(label="RELATED QUERIES")],
         css=css_code,
-        title="TI E2E FORUM"
+        title="TI E2E FORUM Multi-Agent RAG"
     )
 
     interface.launch(debug=True)
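A note on the new indexing step: LangChain's `FAISS.from_embeddings` expects an iterable of `(text, embedding)` pairs plus an `Embeddings` object whose `embed_query` it can call at search time, while `create_faiss_index` above passes a raw NumPy array and a `SentenceTransformer` directly. A minimal sketch of a signature-compatible variant, assuming the same libraries; `STEmbeddings` and `build_faiss_index` are hypothetical names, not part of the commit:

```python
# Sketch only: FAISS.from_embeddings as documented in LangChain.
# STEmbeddings and build_faiss_index are hypothetical helpers.
from langchain.embeddings.base import Embeddings
from langchain.vectorstores import FAISS
from sentence_transformers import SentenceTransformer

class STEmbeddings(Embeddings):
    """Adapter so FAISS can embed queries with the same SentenceTransformer model."""
    def __init__(self, model: SentenceTransformer):
        self.model = model

    def embed_documents(self, texts):
        return self.model.encode(texts).tolist()

    def embed_query(self, text):
        return self.model.encode(text).tolist()

def build_faiss_index(splits, st_model):
    texts = [s.page_content for s in splits]
    vectors = st_model.encode(texts)  # precompute document embeddings once
    return FAISS.from_embeddings(
        list(zip(texts, vectors.tolist())),  # (text, embedding) pairs
        STEmbeddings(st_model),              # used for embed_query at search time
        metadatas=[{"source": f"doc_{i}"} for i in range(len(texts))],
    )
```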
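The body of `initialize_llm` sits in an unchanged region that the diff elides, so only the `BitsAndBytesConfig(` opener is visible. For readers unfamiliar with the pattern, a typical 4-bit setup looks like the sketch below; the Space's actual parameters are not shown in the hunk, and the values here are illustrative only:

```python
# Illustrative only: a common 4-bit BitsAndBytesConfig + model load pattern.
# The commit's actual parameters live in an elided, unchanged region of app.py.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

def initialize_llm(model_id: str):
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quantization_config,
        device_map="auto",  # shard across available devices
    )
    return tokenizer, model
```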
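Because `RetrieverTool` subclasses `transformers.agents.Tool`, it is callable on its own, which makes it easy to sanity-check the index before involving the agent. A sketch, assuming the file above is importable as `app` and using a made-up query:

```python
from app import DocumentRetrievalAndGeneration  # assumes the code above is saved as app.py

system = DocumentRetrievalAndGeneration(
    'thenlper/gte-small',                     # embedding model from the commit
    "meta-llama/Meta-Llama-3.1-8B-Instruct",  # reader/agent model from the commit
    'sample_embedding_folder2',               # data folder from the commit
)

# Tool instances are callable; __call__ dispatches to forward().
print(system.retriever_tool(query="UART logging configuration"))  # hypothetical query
```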
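The combined entry point can be exercised the same way: `query_and_generate_response` runs both the agentic and the standard RAG path and returns the pair that the Gradio UI displays. Continuing the sketch above with another hypothetical question:

```python
# Continuing the previous sketch: run both RAG paths on one question.
answer, related = system.query_and_generate_response("How do I enable UART logging?")
print(answer)   # "Agentic RAG Answer:\n...\n\nStandard RAG Answer:\n..."
print(related)  # always "" in this revision; the RELATED QUERIES box stays empty
```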