Ari committed on
Commit 0bb1965 · verified · 1 parent: 82bfc51

Update app.py

Files changed (1)
  1. app.py +61 -94
app.py CHANGED
@@ -2,33 +2,30 @@ import os
  import streamlit as st
  import pandas as pd
  import sqlite3
- import numpy as np # For numerical operations
  from langchain import OpenAI, LLMChain, PromptTemplate
  import sqlparse
  import logging
- from sklearn.linear_model import LinearRegression # For machine learning tasks
- from sklearn.model_selection import train_test_split
- from sklearn.metrics import mean_squared_error, r2_score

  # Initialize conversation history
  if 'history' not in st.session_state:
      st.session_state.history = []

- # Set up logging
- logging.basicConfig(level=logging.ERROR)
-
  # OpenAI API key (ensure it is securely stored)
  openai_api_key = os.getenv("OPENAI_API_KEY")

- # Set OpenAI API key for langchain
- from langchain.llms import OpenAI as LangchainOpenAI
- LangchainOpenAI.api_key = openai_api_key

  # Step 1: Upload CSV data file (or use default)
- st.title("Data Science Chatbot")
  csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
  if csv_file is None:
-     data = pd.read_csv("default_data.csv") # Use default CSV if no file is uploaded
      st.write("Using default_data.csv file.")
  else:
      data = pd.read_csv(csv_file)
@@ -45,30 +42,16 @@ data.to_sql(table_name, conn, index=False, if_exists='replace')
  valid_columns = list(data.columns)
  st.write(f"Valid columns: {valid_columns}")

- # Step 3: Define helper functions
-
- def extract_code(response):
-     """Extracts code enclosed between <CODE> and </CODE> tags."""
-     import re
-     pattern = r"<CODE>(.*?)</CODE>"
-     match = re.search(pattern, response, re.DOTALL)
-     if match:
-         return match.group(1).strip()
-     else:
-         return None
-
- # Step 4: Set up the LLM Chain to generate SQL queries or Python code
  template = """
- You are an expert data scientist assistant. Given a natural language question, the name of the table, and a list of valid columns, decide whether to generate a SQL query to retrieve data, perform statistical analysis, or build a simple machine learning model.

- Instructions:
- - If the question involves data retrieval or simple aggregations, generate a SQL query.
- - If the question requires statistical analysis, generate a Python code snippet using pandas and numpy.
- - If the question involves predictions or modeling, generate a Python code snippet using scikit-learn.
- - Ensure that you only use the columns provided.
- - Do not include any import statements in the code.
- - For case-insensitive string comparisons in SQL, use either 'LOWER(column) = LOWER(value)' or 'column = value COLLATE NOCASE', but do not use both together.
- - Provide the code between <CODE> and </CODE> tags.

  Question: {question}

@@ -76,12 +59,34 @@ Table name: {table_name}

  Valid columns: {columns}

- Response:
  """
  prompt = PromptTemplate(template=template, input_variables=['question', 'table_name', 'columns'])
- llm = LangchainOpenAI(temperature=0)
  sql_generation_chain = LLMChain(llm=llm, prompt=prompt)

  # Define the callback function
  def process_input():
      user_prompt = st.session_state['user_input']
@@ -96,59 +101,27 @@ def process_input():
              st.session_state.history.append({"role": "assistant", "content": assistant_response})
          else:
              columns = ', '.join(valid_columns)
-             response = sql_generation_chain.run({
                  'question': user_prompt,
                  'table_name': table_name,
                  'columns': columns
              })

-             # Extract code from response
-             code = extract_code(response)
-             if code:
-                 # Determine if the code is SQL or Python
-                 if code.strip().lower().startswith('select'):
-                     # It's a SQL query
-                     st.write(f"Generated SQL Query:\n{code}")
-                     try:
-                         # Execute the SQL query
-                         result = pd.read_sql_query(code, conn)
-                         assistant_response = f"Generated SQL Query:\n{code}"
-                         st.session_state.history.append({"role": "assistant", "content": assistant_response})
-                         st.session_state.history.append({"role": "assistant", "content": result})
-                     except Exception as e:
-                         logging.error(f"An error occurred during SQL execution: {e}")
-                         assistant_response = f"Error executing SQL query: {e}"
-                         st.session_state.history.append({"role": "assistant", "content": assistant_response})
-                 else:
-                     # It's Python code
-                     st.write(f"Generated Python Code:\n{code}")
-                     try:
-                         # Prepare the local namespace
-                         local_vars = {
-                             'pd': pd,
-                             'np': np,
-                             'data': data.copy(),
-                             'result': None,
-                             'LinearRegression': LinearRegression,
-                             'train_test_split': train_test_split,
-                             'mean_squared_error': mean_squared_error,
-                             'r2_score': r2_score
-                         }
-                         exec(code, {}, local_vars)
-                         result = local_vars.get('result')
-                         if result is not None:
-                             assistant_response = "Result:"
-                             st.session_state.history.append({"role": "assistant", "content": assistant_response})
-                             st.session_state.history.append({"role": "assistant", "content": result})
-                         else:
-                             assistant_response = "Code executed successfully."
-                             st.session_state.history.append({"role": "assistant", "content": assistant_response})
-                     except Exception as e:
-                         logging.error(f"An error occurred during code execution: {e}")
-                         assistant_response = f"Error executing code: {e}"
-                         st.session_state.history.append({"role": "assistant", "content": assistant_response})
-             else:
-                 assistant_response = response.strip()
                  st.session_state.history.append({"role": "assistant", "content": assistant_response})

      except Exception as e:
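
The removed branch ran LLM-generated Python through exec() against a restricted local namespace and read the answer back from a result variable. A minimal standalone sketch of that pattern, with a hard-coded snippet and a toy DataFrame standing in for the model output and the uploaded CSV:

    import pandas as pd
    import numpy as np

    # Toy frame standing in for the uploaded CSV (the column name is made up here).
    data = pd.DataFrame({"price": [10.0, 20.0, 30.0]})

    # Hard-coded snippet standing in for model output; exec() runs whatever it
    # is given, so the real app was trusting the generated code.
    code = "result = data['price'].mean()"

    local_vars = {"pd": pd, "np": np, "data": data.copy(), "result": None}
    exec(code, {}, local_vars)  # empty globals, restricted locals
    print(local_vars["result"])  # 20.0
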
@@ -164,17 +137,11 @@ for message in st.session_state.history:
      if message['role'] == 'user':
          st.markdown(f"**User:** {message['content']}")
      elif message['role'] == 'assistant':
-         content = message['content']
-         if isinstance(content, pd.DataFrame):
-             st.markdown("**Assistant:** Here are the results:")
-             st.dataframe(content)
-         elif isinstance(content, (int, float)):
-             st.markdown(f"**Assistant:** {content}")
-         elif isinstance(content, dict):
-             st.markdown("**Assistant:** Here are the results:")
-             st.json(content)
          else:
-             st.markdown(f"**Assistant:** {content}")

  # Place the input field at the bottom with the callback
  st.text_input("Enter your message:", key='user_input', on_change=process_input)
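
Both prompt versions tell the model to make string comparisons case-insensitive with either LOWER() or COLLATE NOCASE (and the old one warns against combining the two). A minimal standalone sketch of how the two forms behave in SQLite; the table and column names are made up for illustration, since the app's real table name is set elsewhere in app.py:

    import sqlite3
    import pandas as pd

    # In-memory stand-in for the uploaded CSV (hypothetical 'city' column).
    conn = sqlite3.connect(":memory:")
    pd.DataFrame({"city": ["Boston", "boston", "Austin"]}).to_sql("data_table", conn, index=False)

    q_collate = "SELECT COUNT(*) AS n FROM data_table WHERE city = 'BOSTON' COLLATE NOCASE"
    q_lower = "SELECT COUNT(*) AS n FROM data_table WHERE LOWER(city) = LOWER('BOSTON')"
    print(pd.read_sql_query(q_collate, conn))  # n = 2
    print(pd.read_sql_query(q_lower, conn))    # n = 2

Either form matches 'Boston' and 'boston' but not 'Austin'; using both on the same comparison is redundant.
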
 
app.py (updated)

  import streamlit as st
  import pandas as pd
  import sqlite3
  from langchain import OpenAI, LLMChain, PromptTemplate
  import sqlparse
  import logging

  # Initialize conversation history
  if 'history' not in st.session_state:
      st.session_state.history = []

  # OpenAI API key (ensure it is securely stored)
+ # You can set the API key in your environment variables or a .env file
  openai_api_key = os.getenv("OPENAI_API_KEY")

+ # Check if the API key is set
+ if not openai_api_key:
+     st.error("OpenAI API key is not set. Please set the OPENAI_API_KEY environment variable.")
+     st.stop()

  # Step 1: Upload CSV data file (or use default)
+ st.title("Natural Language to SQL Query App")
+ st.write("Upload a CSV file to get started, or use the default dataset.")
+
  csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
  if csv_file is None:
+     data = pd.read_csv("default_data.csv") # Ensure this file exists in your working directory
      st.write("Using default_data.csv file.")
  else:
      data = pd.read_csv(csv_file)
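
The new comment says the key can come from the environment or a .env file; app.py itself only calls os.getenv. A small sketch of loading a .env first, assuming the optional python-dotenv package is installed:

    import os

    # Optional: copy variables from a local .env file into the process
    # environment before the os.getenv call (requires python-dotenv).
    try:
        from dotenv import load_dotenv
        load_dotenv()  # reads lines such as OPENAI_API_KEY=... from ./.env
    except ImportError:
        pass

    print("key set" if os.getenv("OPENAI_API_KEY") else "key missing")
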
 
  valid_columns = list(data.columns)
  st.write(f"Valid columns: {valid_columns}")

+ # Step 3: Set up the LLM Chain to generate SQL queries
  template = """
+ You are an expert data scientist. Given a natural language question, the name of the table, and a list of valid columns, generate a valid SQL query that answers the question.
+
+ Ensure that:

+ - You only use the columns provided.
+ - When performing string comparisons in the WHERE clause, make them case-insensitive by using 'COLLATE NOCASE' or the LOWER() function.
+ - Do not use 'COLLATE NOCASE' in ORDER BY clauses unless sorting a string column.
+ - Do not apply 'COLLATE NOCASE' to numeric columns.

  Question: {question}

  Table name: {table_name}

  Valid columns: {columns}

+ SQL Query:
  """
  prompt = PromptTemplate(template=template, input_variables=['question', 'table_name', 'columns'])
+ llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
  sql_generation_chain = LLMChain(llm=llm, prompt=prompt)

+ # Optional: Clean up function to remove incorrect COLLATE NOCASE usage
+ def clean_sql_query(query):
+     """Removes incorrect usage of COLLATE NOCASE from the SQL query."""
+     parsed = sqlparse.parse(query)
+     statements = []
+     for stmt in parsed:
+         tokens = []
+         idx = 0
+         while idx < len(stmt.tokens):
+             token = stmt.tokens[idx]
+             if (token.ttype is sqlparse.tokens.Keyword and token.value.upper() == 'COLLATE'):
+                 # Check if the next token is 'NOCASE'
+                 next_token = stmt.tokens[idx + 2] if idx + 2 < len(stmt.tokens) else None
+                 if next_token and next_token.value.upper() == 'NOCASE':
+                     # Skip 'COLLATE' and 'NOCASE' tokens
+                     idx += 3  # Skip 'COLLATE', whitespace, 'NOCASE'
+                     continue
+             tokens.append(token)
+             idx += 1
+         statements.append(''.join([str(t) for t in tokens]))
+     return ' '.join(statements)
+
  # Define the callback function
  def process_input():
      user_prompt = st.session_state['user_input']

              st.session_state.history.append({"role": "assistant", "content": assistant_response})
          else:
              columns = ', '.join(valid_columns)
+             generated_sql = sql_generation_chain.run({
                  'question': user_prompt,
                  'table_name': table_name,
                  'columns': columns
              })

+             # Debug: Display generated SQL query for inspection
+             st.write(f"Generated SQL Query:\n{generated_sql}")
+
+             # Clean the SQL query
+             generated_sql = clean_sql_query(generated_sql)
+
+             # Attempt to execute SQL query and handle exceptions
+             try:
+                 result = pd.read_sql_query(generated_sql, conn)
+                 assistant_response = f"Generated SQL Query:\n{generated_sql}"
+                 st.session_state.history.append({"role": "assistant", "content": assistant_response})
+                 st.session_state.history.append({"role": "assistant", "content": result})
+             except Exception as e:
+                 logging.error(f"An error occurred during SQL execution: {e}")
+                 assistant_response = f"Error executing SQL query: {e}"
                  st.session_state.history.append({"role": "assistant", "content": assistant_response})

      except Exception as e:
 
      if message['role'] == 'user':
          st.markdown(f"**User:** {message['content']}")
      elif message['role'] == 'assistant':
+         if isinstance(message['content'], pd.DataFrame):
+             st.markdown("**Assistant:** Query Results:")
+             st.dataframe(message['content'])
          else:
+             st.markdown(f"**Assistant:** {message['content']}")

  # Place the input field at the bottom with the callback
  st.text_input("Enter your message:", key='user_input', on_change=process_input)
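
The new clean_sql_query helper only walks each statement's top-level sqlparse tokens: a Keyword token whose value is COLLATE, followed (after whitespace) by NOCASE, is dropped along with that whitespace and the NOCASE token. A standalone sketch for inspecting those top-level tokens on a sample query (the query text is made up; whether a given COLLATE NOCASE appears at the top level, and is therefore stripped, depends on how sqlparse groups the statement):

    import sqlparse

    query = "SELECT name, price FROM data_table ORDER BY price COLLATE NOCASE"
    stmt = sqlparse.parse(query)[0]

    # clean_sql_query iterates exactly this token sequence.
    for tok in stmt.tokens:
        print(repr(tok.ttype), repr(tok.value))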