Ari committed on
Commit
2d80a49
·
verified ·
1 Parent(s): df5408a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -77
app.py CHANGED
@@ -2,18 +2,19 @@ import os
2
  import streamlit as st
3
  import pandas as pd
4
  import sqlite3
5
- from langchain import OpenAI, LLMChain, PromptTemplate
 
 
6
  from langchain_community.utilities import SQLDatabase
 
 
 
 
7
  import sqlparse
8
  import logging
9
- from sql_metadata import Parser
10
 
11
  # OpenAI API key (ensure it is securely stored)
12
- openai_api_key = os.getenv("OPENAI_API_KEY")
13
-
14
- # Initialize conversation history
15
- if 'conversation' not in st.session_state:
16
- st.session_state.conversation = [] # Store previous conversation messages
17
 
18
  # Step 1: Upload CSV data file (or use default)
19
  csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
@@ -25,94 +26,66 @@ else:
25
  st.write(f"Data Preview ({csv_file.name}):")
26
  st.dataframe(data.head())
27
 
28
- # Step 2: Load CSV data into a persistent SQLite database
29
- db_file = 'my_database.db'
30
- conn = sqlite3.connect(db_file)
31
  table_name = csv_file.name.split('.')[0] if csv_file else "default_table"
32
  data.to_sql(table_name, conn, index=False, if_exists='replace')
33
 
34
  # SQL table metadata (for validation and schema)
35
  valid_columns = list(data.columns)
36
 
37
- # Display the conversation thread
38
- st.markdown("### Conversation Thread:")
39
- for message in st.session_state.conversation:
40
- if message.startswith("User:"):
41
- st.markdown(f"<p style='color:blue'><strong>{message}</strong></p>", unsafe_allow_html=True)
42
- else:
43
- st.markdown(f"<p style='color:green'><strong>{message}</strong></p>", unsafe_allow_html=True)
 
 
 
 
44
 
45
- # Step 3: Define SQL validation helpers
 
 
 
 
46
  def validate_sql(query, valid_columns):
47
  """Validates the SQL query by ensuring it references only valid columns."""
48
- parser = Parser(query)
49
- columns_in_query = parser.columns
50
- for column in columns_in_query:
51
- if column not in valid_columns:
52
- return False, f"Invalid column detected: {column}"
53
- return True, None
54
 
55
  def validate_sql_with_sqlparse(query):
56
  """Validates SQL syntax using sqlparse."""
57
  parsed_query = sqlparse.parse(query)
58
  return len(parsed_query) > 0
59
 
60
- # Step 4: Set up the LLM Chain to generate SQL queries
61
- template = """
62
- You are an expert data scientist. Given a natural language question, the name of the table, and a list of valid columns, generate a valid SQL query that answers the question.
63
-
64
- Question: {question}
65
-
66
- Table name: {table_name}
67
-
68
- Valid columns: {columns}
69
-
70
- SQL Query:
71
- """
72
- prompt = PromptTemplate(template=template, input_variables=['question', 'table_name', 'columns'])
73
- sql_generation_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)
74
-
75
- # Step 5: Generate SQL query based on user input
76
- user_prompt = st.text_input("Enter your message:")
77
-
78
  if user_prompt:
79
- # Add user prompt to conversation history
80
- st.session_state.conversation.append(f"User: {user_prompt}")
81
-
82
  try:
83
- # Step 6: Adjust the logic to handle "what are the columns" query
84
- if "columns" in user_prompt.lower():
85
- # Custom logic to return columns
86
- columns_response = f"The columns are: {', '.join(valid_columns)}"
87
- st.session_state.conversation.append(f"Bot: {columns_response}")
 
 
 
 
 
 
 
 
88
  else:
89
- # Generate SQL query based on user input
90
- columns = ', '.join(valid_columns)
91
- generated_sql = sql_generation_chain.run({
92
- 'question': user_prompt,
93
- 'table_name': table_name,
94
- 'columns': columns
95
- })
96
-
97
- # Debug: Display generated SQL query for inspection
98
- st.session_state.conversation.append(f"Bot: Generated SQL Query:\n{generated_sql}")
99
 
100
- # Step 7: Validate SQL query
101
- if not validate_sql_with_sqlparse(generated_sql):
102
- error_message = "Generated SQL is not valid."
103
- st.session_state.conversation.append(f"Bot: {error_message}")
104
- elif not validate_sql(generated_sql, valid_columns)[0]:
105
- invalid_column_message = "Generated SQL references invalid columns."
106
- st.session_state.conversation.append(f"Bot: {invalid_column_message}")
107
- else:
108
- # Step 8: Execute SQL query
109
- result = pd.read_sql_query(generated_sql, conn)
110
- st.session_state.conversation.append("Bot: Here are the results of your query:")
111
- st.session_state.conversation.append(result.to_string(index=False)) # Add query result as string
112
-
113
  except Exception as e:
114
  logging.error(f"An error occurred: {e}")
115
- error_message = f"Error: {e}"
116
- st.session_state.conversation.append(f"Bot: {error_message}")
117
-
118
- # Persist the conversation after each message
 
2
  import streamlit as st
3
  import pandas as pd
4
  import sqlite3
5
+ import openai
6
+ from langchain import OpenAI
7
+ from langchain_community.agent_toolkits.sql.base import create_sql_agent
8
  from langchain_community.utilities import SQLDatabase
9
+ from langchain_community.document_loaders import CSVLoader
10
+ from langchain_community.vectorstores import FAISS
11
+ from langchain_community.embeddings import OpenAIEmbeddings
12
+ from langchain.chains import RetrievalQA
13
  import sqlparse
14
  import logging
 
15
 
16
  # OpenAI API key (ensure it is securely stored)
17
+ openai.api_key = os.getenv("OPENAI_API_KEY")
 
 
 
 
18
 
19
  # Step 1: Upload CSV data file (or use default)
20
  csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
 
26
  st.write(f"Data Preview ({csv_file.name}):")
27
  st.dataframe(data.head())
28
 
29
+ # Step 2: Load CSV data into SQLite database with dynamic table name
30
+ conn = sqlite3.connect(':memory:') # Use an in-memory SQLite database
 
31
  table_name = csv_file.name.split('.')[0] if csv_file else "default_table"
32
  data.to_sql(table_name, conn, index=False, if_exists='replace')
33
 
34
  # SQL table metadata (for validation and schema)
35
  valid_columns = list(data.columns)
36
 
37
+ # Step 3: Set up the SQL Database for LangChain
38
+ db = SQLDatabase.from_uri('sqlite:///:memory:')
39
+ db.raw_connection = conn # Use the in-memory connection for LangChain
40
+
41
+ # Step 4: Create the SQL agent with the correct parameter name
42
+ sql_agent = create_sql_agent(OpenAI(temperature=0), db=db, verbose=True)
43
+
44
+ # Step 5: Use FAISS with RAG for context retrieval
45
+ embeddings = OpenAIEmbeddings()
46
+ loader = CSVLoader(file_path=csv_file.name if csv_file else "default_data.csv")
47
+ documents = loader.load()
48
 
49
+ vector_store = FAISS.from_documents(documents, embeddings)
50
+ retriever = vector_store.as_retriever()
51
+ rag_chain = RetrievalQA.from_chain_type(llm=OpenAI(temperature=0), retriever=retriever)
52
+
53
+ # Step 6: Define SQL validation helpers
54
  def validate_sql(query, valid_columns):
55
  """Validates the SQL query by ensuring it references only valid columns."""
56
+ for column in valid_columns:
57
+ if column not in query:
58
+ return False
59
+ return True
 
 
60
 
61
  def validate_sql_with_sqlparse(query):
62
  """Validates SQL syntax using sqlparse."""
63
  parsed_query = sqlparse.parse(query)
64
  return len(parsed_query) > 0
65
 
66
+ # Step 7: Generate SQL query based on user input and run it with LangChain SQL Agent
67
+ user_prompt = st.text_input("Enter your natural language prompt:")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  if user_prompt:
 
 
 
69
  try:
70
+ # Step 8: Retrieve context using RAG
71
+ context = rag_chain.run(user_prompt)
72
+ st.write(f"Retrieved Context: {context}")
73
+
74
+ # Step 9: Generate SQL query using SQL agent
75
+ generated_sql = sql_agent.run(f"{user_prompt} {context}")
76
+ st.write(f"Generated SQL Query: {generated_sql}")
77
+
78
+ # Step 10: Validate SQL query
79
+ if not validate_sql_with_sqlparse(generated_sql):
80
+ st.write("Generated SQL is not valid.")
81
+ elif not validate_sql(generated_sql, valid_columns):
82
+ st.write("Generated SQL references invalid columns.")
83
  else:
84
+ # Step 11: Execute SQL query
85
+ result = pd.read_sql(generated_sql, conn)
86
+ st.write("Query Results:")
87
+ st.dataframe(result)
 
 
 
 
 
 
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  except Exception as e:
90
  logging.error(f"An error occurred: {e}")
91
+ st.write(f"Error: {e}")