Ari committed on
Commit 82bfc51 · verified · 1 Parent(s): 2d80a49

Update app.py

Files changed (1)
  1. app.py +155 -66
app.py CHANGED
@@ -2,90 +2,179 @@ import os
  import streamlit as st
  import pandas as pd
  import sqlite3
- import openai
- from langchain import OpenAI
- from langchain_community.agent_toolkits.sql.base import create_sql_agent
- from langchain_community.utilities import SQLDatabase
- from langchain_community.document_loaders import CSVLoader
- from langchain_community.vectorstores import FAISS
- from langchain_community.embeddings import OpenAIEmbeddings
- from langchain.chains import RetrievalQA
  import sqlparse
  import logging

  # OpenAI API key (ensure it is securely stored)
- openai.api_key = os.getenv("OPENAI_API_KEY")

  # Step 1: Upload CSV data file (or use default)
  csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
  if csv_file is None:
      data = pd.read_csv("default_data.csv")  # Use default CSV if no file is uploaded
-     st.write("Using default data.csv file.")
  else:
      data = pd.read_csv(csv_file)
      st.write(f"Data Preview ({csv_file.name}):")
      st.dataframe(data.head())

- # Step 2: Load CSV data into SQLite database with dynamic table name
- conn = sqlite3.connect(':memory:')  # Use an in-memory SQLite database
  table_name = csv_file.name.split('.')[0] if csv_file else "default_table"
  data.to_sql(table_name, conn, index=False, if_exists='replace')

  # SQL table metadata (for validation and schema)
  valid_columns = list(data.columns)

- # Step 3: Set up the SQL Database for LangChain
- db = SQLDatabase.from_uri('sqlite:///:memory:')
- db.raw_connection = conn  # Use the in-memory connection for LangChain
-
- # Step 4: Create the SQL agent with the correct parameter name
- sql_agent = create_sql_agent(OpenAI(temperature=0), db=db, verbose=True)
-
- # Step 5: Use FAISS with RAG for context retrieval
- embeddings = OpenAIEmbeddings()
- loader = CSVLoader(file_path=csv_file.name if csv_file else "default_data.csv")
- documents = loader.load()
-
- vector_store = FAISS.from_documents(documents, embeddings)
- retriever = vector_store.as_retriever()
- rag_chain = RetrievalQA.from_chain_type(llm=OpenAI(temperature=0), retriever=retriever)
-
- # Step 6: Define SQL validation helpers
- def validate_sql(query, valid_columns):
-     """Validates the SQL query by ensuring it references only valid columns."""
-     for column in valid_columns:
-         if column not in query:
-             return False
-     return True
-
- def validate_sql_with_sqlparse(query):
-     """Validates SQL syntax using sqlparse."""
-     parsed_query = sqlparse.parse(query)
-     return len(parsed_query) > 0
-
- # Step 7: Generate SQL query based on user input and run it with LangChain SQL Agent
- user_prompt = st.text_input("Enter your natural language prompt:")
- if user_prompt:
-     try:
-         # Step 8: Retrieve context using RAG
-         context = rag_chain.run(user_prompt)
-         st.write(f"Retrieved Context: {context}")
-
-         # Step 9: Generate SQL query using SQL agent
-         generated_sql = sql_agent.run(f"{user_prompt} {context}")
-         st.write(f"Generated SQL Query: {generated_sql}")
-
-         # Step 10: Validate SQL query
-         if not validate_sql_with_sqlparse(generated_sql):
-             st.write("Generated SQL is not valid.")
-         elif not validate_sql(generated_sql, valid_columns):
-             st.write("Generated SQL references invalid columns.")
          else:
-             # Step 11: Execute SQL query
-             result = pd.read_sql(generated_sql, conn)
-             st.write("Query Results:")
-             st.dataframe(result)
-
-     except Exception as e:
-         logging.error(f"An error occurred: {e}")
-         st.write(f"Error: {e}")
 
  import streamlit as st
  import pandas as pd
  import sqlite3
+ import numpy as np  # For numerical operations
+ from langchain import OpenAI, LLMChain, PromptTemplate
  import sqlparse
  import logging
+ from sklearn.linear_model import LinearRegression  # For machine learning tasks
+ from sklearn.model_selection import train_test_split
+ from sklearn.metrics import mean_squared_error, r2_score
+
+ # Initialize conversation history
+ if 'history' not in st.session_state:
+     st.session_state.history = []
+
+ # Set up logging
+ logging.basicConfig(level=logging.ERROR)

  # OpenAI API key (ensure it is securely stored)
+ openai_api_key = os.getenv("OPENAI_API_KEY")
+
+ # Set OpenAI API key for langchain
+ from langchain.llms import OpenAI as LangchainOpenAI
+ LangchainOpenAI.api_key = openai_api_key

  # Step 1: Upload CSV data file (or use default)
+ st.title("Data Science Chatbot")
  csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
  if csv_file is None:
      data = pd.read_csv("default_data.csv")  # Use default CSV if no file is uploaded
+     st.write("Using default_data.csv file.")
  else:
      data = pd.read_csv(csv_file)
      st.write(f"Data Preview ({csv_file.name}):")
      st.dataframe(data.head())

+ # Step 2: Load CSV data into a persistent SQLite database
+ db_file = 'my_database.db'
+ conn = sqlite3.connect(db_file)
  table_name = csv_file.name.split('.')[0] if csv_file else "default_table"
  data.to_sql(table_name, conn, index=False, if_exists='replace')

  # SQL table metadata (for validation and schema)
  valid_columns = list(data.columns)
+ st.write(f"Valid columns: {valid_columns}")
+
+ # Step 3: Define helper functions
+
+ def extract_code(response):
+     """Extracts code enclosed between <CODE> and </CODE> tags."""
+     import re
+     pattern = r"<CODE>(.*?)</CODE>"
+     match = re.search(pattern, response, re.DOTALL)
+     if match:
+         return match.group(1).strip()
+     else:
+         return None
+
+ # Step 4: Set up the LLM Chain to generate SQL queries or Python code
+ template = """
+ You are an expert data scientist assistant. Given a natural language question, the name of the table, and a list of valid columns, decide whether to generate a SQL query to retrieve data, perform statistical analysis, or build a simple machine learning model.
+
+ Instructions:
+ - If the question involves data retrieval or simple aggregations, generate a SQL query.
+ - If the question requires statistical analysis, generate a Python code snippet using pandas and numpy.
+ - If the question involves predictions or modeling, generate a Python code snippet using scikit-learn.
+ - Ensure that you only use the columns provided.
+ - Do not include any import statements in the code.
+ - For case-insensitive string comparisons in SQL, use either 'LOWER(column) = LOWER(value)' or 'column = value COLLATE NOCASE', but do not use both together.
+ - Provide the code between <CODE> and </CODE> tags.
+
+ Question: {question}
+
+ Table name: {table_name}

+ Valid columns: {columns}
+
+ Response:
+ """
+ prompt = PromptTemplate(template=template, input_variables=['question', 'table_name', 'columns'])
+ llm = LangchainOpenAI(temperature=0)
+ sql_generation_chain = LLMChain(llm=llm, prompt=prompt)
+
+ # Define the callback function
+ def process_input():
+     user_prompt = st.session_state['user_input']
+
+     if user_prompt:
+         try:
+             # Append user message to history
+             st.session_state.history.append({"role": "user", "content": user_prompt})
+
+             if "columns" in user_prompt.lower():
+                 assistant_response = f"The columns are: {', '.join(valid_columns)}"
+                 st.session_state.history.append({"role": "assistant", "content": assistant_response})
+             else:
+                 columns = ', '.join(valid_columns)
+                 response = sql_generation_chain.run({
+                     'question': user_prompt,
+                     'table_name': table_name,
+                     'columns': columns
+                 })
+
+                 # Extract code from response
+                 code = extract_code(response)
+                 if code:
+                     # Determine if the code is SQL or Python
+                     if code.strip().lower().startswith('select'):
+                         # It's a SQL query
+                         st.write(f"Generated SQL Query:\n{code}")
+                         try:
+                             # Execute the SQL query
+                             result = pd.read_sql_query(code, conn)
+                             assistant_response = f"Generated SQL Query:\n{code}"
+                             st.session_state.history.append({"role": "assistant", "content": assistant_response})
+                             st.session_state.history.append({"role": "assistant", "content": result})
+                         except Exception as e:
+                             logging.error(f"An error occurred during SQL execution: {e}")
+                             assistant_response = f"Error executing SQL query: {e}"
+                             st.session_state.history.append({"role": "assistant", "content": assistant_response})
+                     else:
+                         # It's Python code
+                         st.write(f"Generated Python Code:\n{code}")
+                         try:
+                             # Prepare the local namespace
+                             local_vars = {
+                                 'pd': pd,
+                                 'np': np,
+                                 'data': data.copy(),
+                                 'result': None,
+                                 'LinearRegression': LinearRegression,
+                                 'train_test_split': train_test_split,
+                                 'mean_squared_error': mean_squared_error,
+                                 'r2_score': r2_score
+                             }
+                             exec(code, {}, local_vars)
+                             result = local_vars.get('result')
+                             if result is not None:
+                                 assistant_response = "Result:"
+                                 st.session_state.history.append({"role": "assistant", "content": assistant_response})
+                                 st.session_state.history.append({"role": "assistant", "content": result})
+                             else:
+                                 assistant_response = "Code executed successfully."
+                                 st.session_state.history.append({"role": "assistant", "content": assistant_response})
+                         except Exception as e:
+                             logging.error(f"An error occurred during code execution: {e}")
+                             assistant_response = f"Error executing code: {e}"
+                             st.session_state.history.append({"role": "assistant", "content": assistant_response})
+                 else:
+                     assistant_response = response.strip()
+                     st.session_state.history.append({"role": "assistant", "content": assistant_response})
+
+         except Exception as e:
+             logging.error(f"An error occurred: {e}")
+             assistant_response = f"Error: {e}"
+             st.session_state.history.append({"role": "assistant", "content": assistant_response})
+
+         # Reset the user_input in session state
+         st.session_state['user_input'] = ''
+
+ # Display the conversation history
+ for message in st.session_state.history:
+     if message['role'] == 'user':
+         st.markdown(f"**User:** {message['content']}")
+     elif message['role'] == 'assistant':
+         content = message['content']
+         if isinstance(content, pd.DataFrame):
+             st.markdown("**Assistant:** Here are the results:")
+             st.dataframe(content)
+         elif isinstance(content, (int, float)):
+             st.markdown(f"**Assistant:** {content}")
+         elif isinstance(content, dict):
+             st.markdown("**Assistant:** Here are the results:")
+             st.json(content)
          else:
+             st.markdown(f"**Assistant:** {content}")
+
+ # Place the input field at the bottom with the callback
+ st.text_input("Enter your message:", key='user_input', on_change=process_input)
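
Illustrative sketch (not from the commit): how the <CODE>-tag extraction and the SQL-vs-Python dispatch in the new process_input callback behave on a sample model response. The sample_response string and its column names are invented for demonstration; no Streamlit session, database connection, or OpenAI key is needed to run it.

import re

def extract_code(response):
    """Extracts code enclosed between <CODE> and </CODE> tags (same logic as in app.py)."""
    match = re.search(r"<CODE>(.*?)</CODE>", response, re.DOTALL)
    return match.group(1).strip() if match else None

# Hypothetical model output, invented for this example.
sample_response = """
Here is the query you asked for:
<CODE>
SELECT category, AVG(price) AS avg_price
FROM default_table
GROUP BY category
</CODE>
"""

code = extract_code(sample_response)
if code is None:
    # app.py would fall back to showing the raw response text.
    print("No <CODE> block found.")
elif code.strip().lower().startswith("select"):
    # In app.py this branch runs pd.read_sql_query(code, conn).
    print("Dispatched as SQL:\n" + code)
else:
    # In app.py this branch runs exec(code, {}, local_vars).
    print("Dispatched as Python:\n" + code)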