nileshhanotia committed
Commit 6a74563 · verified · 1 Parent(s): 9acae5c

Update rag_system.py

Files changed (1)
  1. rag_system.py +44 -16
rag_system.py CHANGED
@@ -6,42 +6,41 @@ from langchain.text_splitter import CharacterTextSplitter
 from langchain.docstore.document import Document
 from transformers import pipeline
 from langchain.prompts import PromptTemplate
+from typing import List, Dict, Any, Optional
 
 class RAGSystem:
-    def __init__(self, csv_path="apparel.csv"):
+    def __init__(self, sql_generator: SQLGenerator, csv_path: str = "apparel.csv"):
+        self.sql_generator = sql_generator
         self.setup_system(csv_path)
         self.qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
 
-    def setup_system(self, csv_path):
+    def setup_system(self, csv_path: str):
         if not os.path.exists(csv_path):
             raise FileNotFoundError(f"CSV file not found at {csv_path}")
 
-        # Read the CSV file
         documents = pd.read_csv(csv_path)
 
-        # Create proper Document objects
         docs = [
             Document(
-                page_content=str(row['Title']),  # Convert to string to ensure compatibility
+                page_content=str(row['Title']),
                 metadata={'index': idx}
             )
             for idx, row in documents.iterrows()
         ]
 
-        # Split documents
         text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
         split_docs = text_splitter.split_documents(docs)
 
-        # Create embeddings and vector store
         embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
         self.vector_store = FAISS.from_documents(split_docs, embeddings)
         self.retriever = self.vector_store.as_retriever()
 
-    def process_query(self, query):
-        # Retrieve documents based on the query
-        retrieved_docs = self.retriever.get_relevant_documents(query)  # Changed from invoke to get_relevant_documents
-
-        # Properly access page_content from Document objects
+    def process_query(self, query: str, execute_sql: bool = True) -> Dict[str, Any]:
+        """
+        Process a query through both RAG and SQL if needed
+        """
+        # Get relevant documents
+        retrieved_docs = self.retriever.get_relevant_documents(query)
         retrieved_text = "\n".join([doc.page_content for doc in retrieved_docs])[:1000]
 
         # Process with QA pipeline
@@ -49,13 +48,42 @@ class RAGSystem:
             "question": query,
             "context": retrieved_text
         }
-        response = self.qa_pipeline(qa_input)
+        qa_response = self.qa_pipeline(qa_input)
 
-        return response['answer']
+        result = {
+            "qa_answer": qa_response['answer'],
+            "relevant_docs": [doc.page_content for doc in retrieved_docs[:3]],
+            "sql_results": None
+        }
+
+        # If SQL execution is requested and SQL is detected in the query
+        if execute_sql and "SELECT" in query.upper():
+            if self.sql_generator.validate_query(query):
+                sql_results = self.sql_generator.execute_query(query)
+                result["sql_results"] = sql_results
+
+        return result
 
-    def get_similar_documents(self, query, k=5):
+    def get_similar_documents(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
        """
        Retrieve similar documents without processing through QA pipeline
        """
        docs = self.retriever.get_relevant_documents(query)
-        return [{'content': doc.page_content, 'metadata': doc.metadata} for doc in docs[:k]]
+        return [{'content': doc.page_content, 'metadata': doc.metadata} for doc in docs[:k]]
+
+# Example usage
+if __name__ == "__main__":
+    # Initialize the SQL generator
+    sql_gen = SQLGenerator("shopify.db")
+
+    # Initialize the RAG system with the SQL generator
+    rag = RAGSystem(sql_gen, "apparel.csv")
+
+    # Example query that might include SQL
+    query = "SELECT * FROM products LIMIT 5"
+    results = rag.process_query(query)
+
+    # Access different parts of the results
+    print("QA Answer:", results["qa_answer"])
+    print("Relevant Documents:", results["relevant_docs"])
+    print("SQL Results:", results["sql_results"])