nileshhanotia committed on
Commit
8cf01c9
·
verified ·
1 Parent(s): 56abc73

Update rag_system.py

Browse files
Files changed (1) hide show
  1. rag_system.py +16 -44
rag_system.py CHANGED
@@ -6,41 +6,42 @@ from langchain.text_splitter import CharacterTextSplitter
6
  from langchain.docstore.document import Document
7
  from transformers import pipeline
8
  from langchain.prompts import PromptTemplate
9
- from typing import List, Dict, Any, Optional
10
 
11
  class RAGSystem:
12
- def __init__(self, sql_generator: SQLGenerator, csv_path: str = "apparel.csv"):
13
- self.sql_generator = sql_generator
14
  self.setup_system(csv_path)
15
  self.qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
16
 
17
- def setup_system(self, csv_path: str):
18
  if not os.path.exists(csv_path):
19
  raise FileNotFoundError(f"CSV file not found at {csv_path}")
20
 
 
21
  documents = pd.read_csv(csv_path)
22
 
 
23
  docs = [
24
  Document(
25
- page_content=str(row['Title']),
26
  metadata={'index': idx}
27
  )
28
  for idx, row in documents.iterrows()
29
  ]
30
 
 
31
  text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
32
  split_docs = text_splitter.split_documents(docs)
33
 
 
34
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
35
  self.vector_store = FAISS.from_documents(split_docs, embeddings)
36
  self.retriever = self.vector_store.as_retriever()
37
 
38
- def process_query(self, query: str, execute_sql: bool = True) -> Dict[str, Any]:
39
- """
40
- Process a query through both RAG and SQL if needed
41
- """
42
- # Get relevant documents
43
- retrieved_docs = self.retriever.get_relevant_documents(query)
44
  retrieved_text = "\n".join([doc.page_content for doc in retrieved_docs])[:1000]
45
 
46
  # Process with QA pipeline
@@ -48,42 +49,13 @@ class RAGSystem:
48
  "question": query,
49
  "context": retrieved_text
50
  }
51
- qa_response = self.qa_pipeline(qa_input)
52
 
53
- result = {
54
- "qa_answer": qa_response['answer'],
55
- "relevant_docs": [doc.page_content for doc in retrieved_docs[:3]],
56
- "sql_results": None
57
- }
58
-
59
- # If SQL execution is requested and SQL is detected in the query
60
- if execute_sql and "SELECT" in query.upper():
61
- if self.sql_generator.validate_query(query):
62
- sql_results = self.sql_generator.execute_query(query)
63
- result["sql_results"] = sql_results
64
-
65
- return result
66
 
67
- def get_similar_documents(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
68
  """
69
  Retrieve similar documents without processing through QA pipeline
70
  """
71
  docs = self.retriever.get_relevant_documents(query)
72
- return [{'content': doc.page_content, 'metadata': doc.metadata} for doc in docs[:k]]
73
-
74
- # Example usage
75
- if __name__ == "__main__":
76
- # Initialize the SQL generator
77
- sql_gen = SQLGenerator("shopify.db")
78
-
79
- # Initialize the RAG system with the SQL generator
80
- rag = RAGSystem(sql_gen, "apparel.csv")
81
-
82
- # Example query that might include SQL
83
- query = "SELECT * FROM products LIMIT 5"
84
- results = rag.process_query(query)
85
-
86
- # Access different parts of the results
87
- print("QA Answer:", results["qa_answer"])
88
- print("Relevant Documents:", results["relevant_docs"])
89
- print("SQL Results:", results["sql_results"])
 
6
  from langchain.docstore.document import Document
7
  from transformers import pipeline
8
  from langchain.prompts import PromptTemplate
 
9
 
10
  class RAGSystem:
11
+ def __init__(self, csv_path="apparel.csv"):
 
12
  self.setup_system(csv_path)
13
  self.qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
14
 
15
+ def setup_system(self, csv_path):
16
  if not os.path.exists(csv_path):
17
  raise FileNotFoundError(f"CSV file not found at {csv_path}")
18
 
19
+ # Read the CSV file
20
  documents = pd.read_csv(csv_path)
21
 
22
+ # Create proper Document objects
23
  docs = [
24
  Document(
25
+ page_content=str(row['Title']), # Convert to string to ensure compatibility
26
  metadata={'index': idx}
27
  )
28
  for idx, row in documents.iterrows()
29
  ]
30
 
31
+ # Split documents
32
  text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
33
  split_docs = text_splitter.split_documents(docs)
34
 
35
+ # Create embeddings and vector store
36
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
37
  self.vector_store = FAISS.from_documents(split_docs, embeddings)
38
  self.retriever = self.vector_store.as_retriever()
39
 
40
+ def process_query(self, query):
41
+ # Retrieve documents based on the query
42
+ retrieved_docs = self.retriever.get_relevant_documents(query) # Changed from invoke to get_relevant_documents
43
+
44
+ # Properly access page_content from Document objects
 
45
  retrieved_text = "\n".join([doc.page_content for doc in retrieved_docs])[:1000]
46
 
47
  # Process with QA pipeline
 
49
  "question": query,
50
  "context": retrieved_text
51
  }
52
+ response = self.qa_pipeline(qa_input)
53
 
54
+ return response['answer']
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ def get_similar_documents(self, query, k=5):
57
  """
58
  Retrieve similar documents without processing through QA pipeline
59
  """
60
  docs = self.retriever.get_relevant_documents(query)
61
+ return [{'content': doc.page_content, 'metadata': doc.metadata} for doc in docs[:k]]