Ari committed on
Commit
2599708
·
verified ·
1 Parent(s): e37eda0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -10
app.py CHANGED
@@ -4,7 +4,7 @@ import pandas as pd
4
  import sqlite3
5
  import openai
6
  from langchain import OpenAI
7
- from langchain_community.agent_toolkits.sql.base import create_sql_agent
8
  from langchain_community.utilities import SQLDatabase
9
  from langchain_community.document_loaders import CSVLoader
10
  from langchain_community.vectorstores import FAISS
@@ -34,12 +34,17 @@ data.to_sql(table_name, conn, index=False, if_exists='replace')
34
  # SQL table metadata (for validation and schema)
35
  valid_columns = list(data.columns)
36
 
37
- # Step 3: Use a SQL Agent and setup LangChain's SQL Database connection
38
  db = SQLDatabase.from_uri('sqlite:///:memory:')
39
  db.raw_connection = conn # Use the in-memory connection for LangChain
40
- sql_agent = create_sql_agent(OpenAI(temperature=0), db, verbose=True)
41
 
42
- # Step 4: Use FAISS with RAG for context retrieval
 
 
 
 
 
 
43
  embeddings = OpenAIEmbeddings()
44
  loader = CSVLoader(file_path=csv_file.name if csv_file else "default_data.csv")
45
  documents = loader.load()
@@ -48,7 +53,7 @@ vector_store = FAISS.from_documents(documents, embeddings)
48
  retriever = vector_store.as_retriever()
49
  rag_chain = RetrievalQA.from_chain_type(llm=OpenAI(temperature=0), retriever=retriever)
50
 
51
- # Step 5: Define SQL validation helpers
52
  def validate_sql(query, valid_columns):
53
  """Validates the SQL query by ensuring it references only valid columns."""
54
  for column in valid_columns:
@@ -61,25 +66,25 @@ def validate_sql_with_sqlparse(query):
61
  parsed_query = sqlparse.parse(query)
62
  return len(parsed_query) > 0
63
 
64
- # Step 6: Generate SQL query based on user input and run it with LangChain SQL Agent
65
  user_prompt = st.text_input("Enter your natural language prompt:")
66
  if user_prompt:
67
  try:
68
- # Step 7: Retrieve context using RAG
69
  context = rag_chain.run(user_prompt)
70
  st.write(f"Retrieved Context: {context}")
71
 
72
- # Step 8: Generate SQL query using SQL agent
73
  generated_sql = sql_agent.run(f"{user_prompt} {context}")
74
  st.write(f"Generated SQL Query: {generated_sql}")
75
 
76
- # Step 9: Validate SQL query
77
  if not validate_sql_with_sqlparse(generated_sql):
78
  st.write("Generated SQL is not valid.")
79
  elif not validate_sql(generated_sql, valid_columns):
80
  st.write("Generated SQL references invalid columns.")
81
  else:
82
- # Step 10: Execute SQL query
83
  result = pd.read_sql(generated_sql, conn)
84
  st.write("Query Results:")
85
  st.dataframe(result)
 
4
  import sqlite3
5
  import openai
6
  from langchain import OpenAI
7
+ from langchain_community.agent_toolkits.sql.base import SQLDatabaseToolkit, create_sql_agent # SQL toolkit import
8
  from langchain_community.utilities import SQLDatabase
9
  from langchain_community.document_loaders import CSVLoader
10
  from langchain_community.vectorstores import FAISS
 
34
  # SQL table metadata (for validation and schema)
35
  valid_columns = list(data.columns)
36
 
37
+ # Step 3: Set up the SQL Database for LangChain
38
  db = SQLDatabase.from_uri('sqlite:///:memory:')
39
  db.raw_connection = conn # Use the in-memory connection for LangChain
 
40
 
41
+ # Step 4: Create the SQL toolkit
42
+ sql_toolkit = SQLDatabaseToolkit(db)
43
+
44
+ # Step 5: Create the SQL agent using the toolkit
45
+ sql_agent = create_sql_agent(OpenAI(temperature=0), toolkit=sql_toolkit, verbose=True)
46
+
47
+ # Step 6: Use FAISS with RAG for context retrieval
48
  embeddings = OpenAIEmbeddings()
49
  loader = CSVLoader(file_path=csv_file.name if csv_file else "default_data.csv")
50
  documents = loader.load()
 
53
  retriever = vector_store.as_retriever()
54
  rag_chain = RetrievalQA.from_chain_type(llm=OpenAI(temperature=0), retriever=retriever)
55
 
56
+ # Step 7: Define SQL validation helpers
57
  def validate_sql(query, valid_columns):
58
  """Validates the SQL query by ensuring it references only valid columns."""
59
  for column in valid_columns:
 
66
  parsed_query = sqlparse.parse(query)
67
  return len(parsed_query) > 0
68
 
69
+ # Step 8: Generate SQL query based on user input and run it with LangChain SQL Agent
70
  user_prompt = st.text_input("Enter your natural language prompt:")
71
  if user_prompt:
72
  try:
73
+ # Step 9: Retrieve context using RAG
74
  context = rag_chain.run(user_prompt)
75
  st.write(f"Retrieved Context: {context}")
76
 
77
+ # Step 10: Generate SQL query using SQL agent
78
  generated_sql = sql_agent.run(f"{user_prompt} {context}")
79
  st.write(f"Generated SQL Query: {generated_sql}")
80
 
81
+ # Step 11: Validate SQL query
82
  if not validate_sql_with_sqlparse(generated_sql):
83
  st.write("Generated SQL is not valid.")
84
  elif not validate_sql(generated_sql, valid_columns):
85
  st.write("Generated SQL references invalid columns.")
86
  else:
87
+ # Step 12: Execute SQL query
88
  result = pd.read_sql(generated_sql, conn)
89
  st.write("Query Results:")
90
  st.dataframe(result)