Spaces:
Sleeping
Sleeping
Ari
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -79,23 +79,27 @@ def validate_sql_with_sqlparse(query):
|
|
79 |
user_prompt = st.text_input("Enter your natural language prompt:")
|
80 |
if user_prompt:
|
81 |
try:
|
82 |
-
# Step 8:
|
83 |
-
|
|
|
|
|
|
|
|
|
84 |
st.write(f"Retrieved Context: {context}")
|
85 |
|
86 |
-
# Step
|
87 |
-
generated_sql = sql_agent.run(f"{
|
88 |
|
89 |
# Debug: Display generated SQL query for inspection
|
90 |
st.write(f"Generated SQL Query: {generated_sql}")
|
91 |
|
92 |
-
# Step
|
93 |
if not validate_sql_with_sqlparse(generated_sql):
|
94 |
st.write("Generated SQL is not valid.")
|
95 |
elif not validate_sql(generated_sql, valid_columns):
|
96 |
st.write("Generated SQL references invalid columns.")
|
97 |
else:
|
98 |
-
# Step
|
99 |
result = pd.read_sql(generated_sql, conn)
|
100 |
st.write("Query Results:")
|
101 |
st.dataframe(result)
|
|
|
79 |
user_prompt = st.text_input("Enter your natural language prompt:")
|
80 |
if user_prompt:
|
81 |
try:
|
82 |
+
# Step 8: Add valid column names to the prompt
|
83 |
+
column_hints = f" Use only these columns: {', '.join(valid_columns)}"
|
84 |
+
prompt_with_columns = user_prompt + column_hints
|
85 |
+
|
86 |
+
# Step 9: Retrieve context using RAG
|
87 |
+
context = rag_chain.run(prompt_with_columns)
|
88 |
st.write(f"Retrieved Context: {context}")
|
89 |
|
90 |
+
# Step 10: Generate SQL query using SQL agent
|
91 |
+
generated_sql = sql_agent.run(f"{prompt_with_columns} {context}")
|
92 |
|
93 |
# Debug: Display generated SQL query for inspection
|
94 |
st.write(f"Generated SQL Query: {generated_sql}")
|
95 |
|
96 |
+
# Step 11: Validate SQL query
|
97 |
if not validate_sql_with_sqlparse(generated_sql):
|
98 |
st.write("Generated SQL is not valid.")
|
99 |
elif not validate_sql(generated_sql, valid_columns):
|
100 |
st.write("Generated SQL references invalid columns.")
|
101 |
else:
|
102 |
+
# Step 12: Execute SQL query
|
103 |
result = pd.read_sql(generated_sql, conn)
|
104 |
st.write("Query Results:")
|
105 |
st.dataframe(result)
|