Ari commited on
Commit
a511dd2
·
verified ·
1 Parent(s): 73b2770

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -79,23 +79,27 @@ def validate_sql_with_sqlparse(query):
79
  user_prompt = st.text_input("Enter your natural language prompt:")
80
  if user_prompt:
81
  try:
82
- # Step 8: Retrieve context using RAG
83
- context = rag_chain.run(user_prompt)
 
 
 
 
84
  st.write(f"Retrieved Context: {context}")
85
 
86
- # Step 9: Generate SQL query using SQL agent
87
- generated_sql = sql_agent.run(f"{user_prompt} {context}")
88
 
89
  # Debug: Display generated SQL query for inspection
90
  st.write(f"Generated SQL Query: {generated_sql}")
91
 
92
- # Step 10: Validate SQL query
93
  if not validate_sql_with_sqlparse(generated_sql):
94
  st.write("Generated SQL is not valid.")
95
  elif not validate_sql(generated_sql, valid_columns):
96
  st.write("Generated SQL references invalid columns.")
97
  else:
98
- # Step 11: Execute SQL query
99
  result = pd.read_sql(generated_sql, conn)
100
  st.write("Query Results:")
101
  st.dataframe(result)
 
79
  user_prompt = st.text_input("Enter your natural language prompt:")
80
  if user_prompt:
81
  try:
82
+ # Step 8: Add valid column names to the prompt
83
+ column_hints = f" Use only these columns: {', '.join(valid_columns)}"
84
+ prompt_with_columns = user_prompt + column_hints
85
+
86
+ # Step 9: Retrieve context using RAG
87
+ context = rag_chain.run(prompt_with_columns)
88
  st.write(f"Retrieved Context: {context}")
89
 
90
+ # Step 10: Generate SQL query using SQL agent
91
+ generated_sql = sql_agent.run(f"{prompt_with_columns} {context}")
92
 
93
  # Debug: Display generated SQL query for inspection
94
  st.write(f"Generated SQL Query: {generated_sql}")
95
 
96
+ # Step 11: Validate SQL query
97
  if not validate_sql_with_sqlparse(generated_sql):
98
  st.write("Generated SQL is not valid.")
99
  elif not validate_sql(generated_sql, valid_columns):
100
  st.write("Generated SQL references invalid columns.")
101
  else:
102
+ # Step 12: Execute SQL query
103
  result = pd.read_sql(generated_sql, conn)
104
  st.write("Query Results:")
105
  st.dataframe(result)