Ari committed on
Commit
9e9d1c1
·
verified ·
1 Parent(s): e9b5d63

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -36
app.py CHANGED
@@ -3,7 +3,7 @@ import streamlit as st
3
  import pandas as pd
4
  import sqlite3
5
  from langchain import OpenAI, LLMChain, PromptTemplate
6
- from langchain_community.utilities import SQLDatabase
7
  import sqlparse
8
  import logging
9
  from sql_metadata import Parser
@@ -12,47 +12,104 @@ from sql_metadata import Parser
12
  if 'history' not in st.session_state:
13
  st.session_state.history = []
14
 
15
- # Process user input
16
- user_prompt = st.text_input("Enter your message:", key='user_input')
17
 
18
- if user_prompt:
19
- try:
20
- if "columns" in user_prompt.lower():
21
- assistant_response = f"The columns are: {', '.join(valid_columns)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  st.session_state.history.append({"role": "user", "content": user_prompt})
23
- st.session_state.history.append({"role": "assistant", "content": assistant_response})
24
- else:
25
- columns = ', '.join(valid_columns)
26
- generated_sql = sql_generation_chain.run({
27
- 'question': user_prompt,
28
- 'table_name': table_name,
29
- 'columns': columns
30
- })
31
-
32
- # Validate SQL query
33
- if not validate_sql_with_sqlparse(generated_sql):
34
- assistant_response = "Generated SQL is not valid."
35
- elif not validate_sql(generated_sql, valid_columns):
36
- assistant_response = "Generated SQL references invalid columns."
37
  else:
38
- # Execute SQL query
39
- result = pd.read_sql_query(generated_sql, conn)
40
- assistant_response = f"Generated SQL Query:\n{generated_sql}\n\nQuery Results:"
41
- st.session_state.history.append({"role": "assistant", "content": result})
 
 
42
 
43
- # Append user and assistant messages to history
44
- st.session_state.history.append({"role": "user", "content": user_prompt})
45
- st.session_state.history.append({"role": "assistant", "content": assistant_response})
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- except Exception as e:
48
- logging.error(f"An error occurred: {e}")
49
- assistant_response = f"Error: {e}"
50
- st.session_state.history.append({"role": "assistant", "content": assistant_response})
51
 
52
- # Clear the input field
53
- st.session_state.user_input = ''
54
- # Rerun the script to update the conversation display
55
- st.experimental_rerun()
56
 
57
  # Display the conversation history
58
  for message in st.session_state.history:
@@ -64,3 +121,6 @@ for message in st.session_state.history:
64
  st.dataframe(message['content'])
65
  else:
66
  st.markdown(f"**Assistant:** {message['content']}")
 
 
 
 
3
  import pandas as pd
4
  import sqlite3
5
  from langchain import OpenAI, LLMChain, PromptTemplate
6
+ # Removed unused import: from langchain_community.utilities import SQLDatabase
7
  import sqlparse
8
  import logging
9
  from sql_metadata import Parser
 
12
  if 'history' not in st.session_state:
13
  st.session_state.history = []
14
 
15
+ # OpenAI API key (ensure it is securely stored)
16
+ openai_api_key = os.getenv("OPENAI_API_KEY")
17
 
18
+ # Step 1: Upload CSV data file (or use default)
19
+ csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
20
+ if csv_file is None:
21
+ data = pd.read_csv("default_data.csv") # Use default CSV if no file is uploaded
22
+ st.write("Using default_data.csv file.")
23
+ else:
24
+ data = pd.read_csv(csv_file)
25
+ st.write(f"Data Preview ({csv_file.name}):")
26
+ st.dataframe(data.head())
27
+
28
+ # Step 2: Load CSV data into a persistent SQLite database
29
+ db_file = 'my_database.db'
30
+ conn = sqlite3.connect(db_file)
31
+ table_name = csv_file.name.split('.')[0] if csv_file else "default_table"
32
+ data.to_sql(table_name, conn, index=False, if_exists='replace')
33
+
34
+ # SQL table metadata (for validation and schema)
35
+ valid_columns = list(data.columns)
36
+ st.write(f"Valid columns: {valid_columns}")
37
+
38
+ # Step 3: Define SQL validation helpers
39
def validate_sql(query, valid_columns):
    """Check that every column referenced by ``query`` exists in the table.

    Args:
        query: SQL text to inspect (typically produced by the LLM chain).
        valid_columns: Iterable of column names present in the loaded table.

    Returns:
        True when every referenced column is known, otherwise False. The
        first unknown column is reported in the Streamlit UI via ``st.write``.
    """
    # SQLite resolves column names case-insensitively, so compare casefolded
    # names instead of rejecting e.g. "Name" when the CSV header is "name".
    known = {str(name).casefold() for name in valid_columns}

    for column in Parser(query).columns:
        # The parser may report table-qualified names ("table.column") and
        # the "*" wildcard; reduce to the bare column and let "*" through.
        bare = column.rsplit(".", 1)[-1].strip('`"[]')
        if bare == "*":
            continue
        # Also try the unsplit name so CSV headers that themselves contain a
        # dot keep validating exactly as before.
        if {column.casefold(), bare.casefold()}.isdisjoint(known):
            st.write(f"Invalid column detected: {column}")
            return False
    return True
48
+
49
def validate_sql_with_sqlparse(query):
    """Accept ``query`` only if it is exactly one SELECT statement.

    sqlparse is a non-validating parser: ``sqlparse.parse`` returns at least
    one statement for almost any non-empty text, so merely counting the
    parsed statements cannot reject anything. This check instead requires a
    single non-blank statement whose type is SELECT, which also keeps
    multi-statement or data-modifying output away from the query executor.

    Args:
        query: SQL text to check.

    Returns:
        True if ``query`` holds one SELECT statement, otherwise False.
    """
    if not query or not query.strip():
        return False

    # Ignore whitespace-only fragments such as the text after a trailing ";".
    statements = [stmt for stmt in sqlparse.parse(query) if str(stmt).strip()]
    if len(statements) != 1:
        return False

    return statements[0].get_type() == "SELECT"
53
+
54
+ # Step 4: Set up the LLM Chain to generate SQL queries
55
+ template = """
56
+ You are an expert data scientist. Given a natural language question, the name of the table, and a list of valid columns, generate a valid SQL query that answers the question.
57
+
58
+ Question: {question}
59
+
60
+ Table name: {table_name}
61
+
62
+ Valid columns: {columns}
63
+
64
+ SQL Query:
65
+ """
66
+ prompt = PromptTemplate(template=template, input_variables=['question', 'table_name', 'columns'])
67
+ sql_generation_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)
68
+
69
+ # Define the callback function
70
def process_input():
    """Handle a submitted chat message (``on_change`` callback of the input).

    Reads the text stored under ``st.session_state['user_input']``, records
    it in ``st.session_state.history`` and appends the assistant's reply:

    * a listing of the table columns when the message mentions "columns";
    * otherwise the SQL generated by ``sql_generation_chain`` followed by its
      result DataFrame, or a short explanation when the SQL fails validation.

    Any exception is logged and surfaced in the history as an "Error: ..."
    message. The input field is cleared once a non-empty message was handled.
    """
    import contextlib

    user_prompt = st.session_state['user_input']
    if not user_prompt:
        return

    history = st.session_state.history

    def reply(content):
        # Single place that records an assistant message (text or DataFrame).
        history.append({"role": "assistant", "content": content})

    try:
        # Append user message to history
        history.append({"role": "user", "content": user_prompt})

        if "columns" in user_prompt.lower():
            reply(f"The columns are: {', '.join(valid_columns)}")
        else:
            # LLM completions commonly carry leading/trailing whitespace.
            generated_sql = sql_generation_chain.run({
                'question': user_prompt,
                'table_name': table_name,
                'columns': ', '.join(valid_columns),
            }).strip()

            if not validate_sql_with_sqlparse(generated_sql):
                reply("Generated SQL is not valid.")
            elif not validate_sql(generated_sql, valid_columns):
                reply("Generated SQL references invalid columns.")
            else:
                # Use a short-lived connection instead of the module-level
                # `conn`: sqlite3 connections are by default usable only from
                # the thread that created them, and this callback may run on
                # a different thread than the script run that opened `conn`.
                with contextlib.closing(sqlite3.connect(db_file)) as query_conn:
                    result = pd.read_sql_query(generated_sql, query_conn)
                reply(f"Generated SQL Query:\n{generated_sql}")
                reply(result)
    except Exception as e:
        # logging.exception keeps the traceback alongside the message.
        logging.exception("An error occurred: %s", e)
        reply(f"Error: {e}")
    finally:
        # Reset the user_input in session state
        st.session_state['user_input'] = ''
 
 
113
 
114
  # Display the conversation history
115
  for message in st.session_state.history:
 
121
  st.dataframe(message['content'])
122
  else:
123
  st.markdown(f"**Assistant:** {message['content']}")
124
+
125
+ # Place the input field at the bottom with the callback
126
+ st.text_input("Enter your message:", key='user_input', on_change=process_input)