Ari commited on
Commit
746d24c
·
verified ·
1 Parent(s): e1fc8bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -47
app.py CHANGED
@@ -13,7 +13,7 @@ openai_api_key = os.getenv("OPENAI_API_KEY")
13
 
14
  # Initialize conversation history
15
  if 'conversation' not in st.session_state:
16
- st.session_state.conversation = []
17
 
18
  # Step 1: Upload CSV data file (or use default)
19
  csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
@@ -33,30 +33,11 @@ data.to_sql(table_name, conn, index=False, if_exists='replace')
33
 
34
  # SQL table metadata (for validation and schema)
35
  valid_columns = list(data.columns)
36
- st.write(f"Valid columns: {valid_columns}")
37
-
38
- # Function to extract column names from the question
39
- def extract_column_name(question, valid_columns):
40
- for column in valid_columns:
41
- if column.lower() in question.lower():
42
- return column
43
- return None
44
-
45
- # Function to generate statistical insights
46
- def generate_statistical_insights(question, data):
47
- if "mean" in question.lower():
48
- column = extract_column_name(question, valid_columns)
49
- if column:
50
- mean_value = data[column].mean()
51
- st.session_state.conversation.append(f"Mean of {column}: {mean_value}")
52
- else:
53
- st.session_state.conversation.append(f"Could not find a valid column in the question.")
54
- elif "median" in question.lower():
55
- column = extract_column_name(question, valid_columns)
56
- if column:
57
- median_value = data[column].median()
58
- st.session_state.conversation.append(f"Median of {column}: {median_value}")
59
- # Add more statistical insights (mode, std, etc.)
60
 
61
  # Step 3: Define SQL validation helpers
62
  def validate_sql(query, valid_columns):
@@ -65,9 +46,8 @@ def validate_sql(query, valid_columns):
65
  columns_in_query = parser.columns
66
  for column in columns_in_query:
67
  if column not in valid_columns:
68
- st.session_state.conversation.append(f"Invalid column detected: {column}")
69
- return False
70
- return True
71
 
72
  def validate_sql_with_sqlparse(query):
73
  """Validates SQL syntax using sqlparse."""
@@ -89,36 +69,51 @@ SQL Query:
89
  prompt = PromptTemplate(template=template, input_variables=['question', 'table_name', 'columns'])
90
  sql_generation_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)
91
 
92
- # Display conversation history like a text thread
93
- st.write("### Conversation Thread")
94
- for chat in st.session_state.conversation:
95
- st.write(f"User: {chat}")
96
 
97
- # Step 5: Generate SQL query or statistical insights based on user input
98
- user_prompt = st.text_input("Enter your question or prompt here:")
99
  if user_prompt:
 
 
 
100
  try:
101
- # Step 6: Handle statistical insights or generate SQL
102
- if any(stat_term in user_prompt.lower() for stat_term in ["mean", "median", "mode", "std"]):
103
- generate_statistical_insights(user_prompt, data)
 
 
 
104
  else:
 
105
  columns = ', '.join(valid_columns)
106
- generated_sql = sql_generation_chain.run({'question': user_prompt, 'table_name': table_name, 'columns': columns})
 
 
 
 
107
 
108
- # Display generated SQL query in the conversation thread
109
- st.session_state.conversation.append(f"Generated SQL Query: {generated_sql}")
110
 
111
  # Step 7: Validate SQL query
112
  if not validate_sql_with_sqlparse(generated_sql):
113
- st.session_state.conversation.append("Generated SQL is not valid.")
114
- elif not validate_sql(generated_sql, valid_columns):
115
- st.session_state.conversation.append("Generated SQL references invalid columns.")
 
 
 
 
116
  else:
117
  # Step 8: Execute SQL query
118
  result = pd.read_sql_query(generated_sql, conn)
119
- st.session_state.conversation.append("Query Results:")
120
- st.session_state.conversation.append(result.to_string())
121
-
122
  except Exception as e:
123
  logging.error(f"An error occurred: {e}")
124
- st.session_state.conversation.append(f"Error: {e}")
 
 
 
 
 
13
 
14
  # Initialize conversation history
15
  if 'conversation' not in st.session_state:
16
+ st.session_state.conversation = [] # Store previous conversation messages
17
 
18
  # Step 1: Upload CSV data file (or use default)
19
  csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
 
33
 
34
  # SQL table metadata (for validation and schema)
35
  valid_columns = list(data.columns)
36
+
37
+ # Display the conversation thread
38
+ st.write("### Conversation Thread:")
39
+ for message in st.session_state.conversation:
40
+ st.write(message)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  # Step 3: Define SQL validation helpers
43
  def validate_sql(query, valid_columns):
 
46
  columns_in_query = parser.columns
47
  for column in columns_in_query:
48
  if column not in valid_columns:
49
+ return False, f"Invalid column detected: {column}"
50
+ return True, None
 
51
 
52
  def validate_sql_with_sqlparse(query):
53
  """Validates SQL syntax using sqlparse."""
 
69
  prompt = PromptTemplate(template=template, input_variables=['question', 'table_name', 'columns'])
70
  sql_generation_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)
71
 
72
+ # Step 5: Generate SQL query based on user input
73
+ user_prompt = st.text_input("Enter your message:")
 
 
74
 
 
 
75
  if user_prompt:
76
+ # Add user prompt to conversation history
77
+ st.session_state.conversation.append(f"User: {user_prompt}")
78
+
79
  try:
80
+ # Step 6: Adjust the logic to handle "what are the columns" query
81
+ if "columns" in user_prompt.lower():
82
+ # Custom logic to return columns
83
+ columns_response = f"The columns are: {', '.join(valid_columns)}"
84
+ st.session_state.conversation.append(f"Bot: {columns_response}")
85
+ st.write(f"The columns are: {', '.join(valid_columns)}")
86
  else:
87
+ # Generate SQL query based on user input
88
  columns = ', '.join(valid_columns)
89
+ generated_sql = sql_generation_chain.run({
90
+ 'question': user_prompt,
91
+ 'table_name': table_name,
92
+ 'columns': columns
93
+ })
94
 
95
+ # Debug: Display generated SQL query for inspection
96
+ st.write(f"Generated SQL Query:\n{generated_sql}")
97
 
98
  # Step 7: Validate SQL query
99
  if not validate_sql_with_sqlparse(generated_sql):
100
+ error_message = "Generated SQL is not valid."
101
+ st.session_state.conversation.append(f"Bot: {error_message}")
102
+ st.write(error_message)
103
+ elif not validate_sql(generated_sql, valid_columns)[0]:
104
+ invalid_column_message = "Generated SQL references invalid columns."
105
+ st.session_state.conversation.append(f"Bot: {invalid_column_message}")
106
+ st.write(invalid_column_message)
107
  else:
108
  # Step 8: Execute SQL query
109
  result = pd.read_sql_query(generated_sql, conn)
110
+ st.write("Query Results:")
111
+ st.dataframe(result)
112
+ st.session_state.conversation.append("Bot: Here are the results of your query.")
113
  except Exception as e:
114
  logging.error(f"An error occurred: {e}")
115
+ error_message = f"Error: {e}"
116
+ st.session_state.conversation.append(f"Bot: {error_message}")
117
+ st.write(f"Error: {e}")
118
+
119
+ # Persist the conversation after each message