Ari committed on
Commit
e37eda0
1 Parent(s): 1d9b999

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -27
app.py CHANGED
@@ -1,16 +1,17 @@
1
- import os # Add this line at the top of your file
2
  import streamlit as st
3
  import pandas as pd
4
  import sqlite3
5
  import openai
6
- from transformers import pipeline # Using Hugging Face pipeline for memory-efficient loading
7
  from langchain import OpenAI
8
  from langchain_community.agent_toolkits.sql.base import create_sql_agent
9
  from langchain_community.utilities import SQLDatabase
10
  from langchain_community.document_loaders import CSVLoader
11
  from langchain_community.vectorstores import FAISS
12
  from langchain_community.embeddings import OpenAIEmbeddings
 
13
  import sqlparse
 
14
 
15
  # OpenAI API key (ensure it is securely stored)
16
  openai.api_key = os.getenv("OPENAI_API_KEY")
@@ -27,33 +28,27 @@ else:
27
 
28
  # Step 2: Load CSV data into SQLite database with dynamic table name
29
  conn = sqlite3.connect(':memory:') # Use an in-memory SQLite database
30
-
31
- # Dynamically name the table based on the uploaded file name or fallback to a default name
32
  table_name = csv_file.name.split('.')[0] if csv_file else "default_table"
33
  data.to_sql(table_name, conn, index=False, if_exists='replace')
34
 
35
  # SQL table metadata (for validation and schema)
36
  valid_columns = list(data.columns)
37
 
38
- # Step 3: Use a smaller LLaMA for context retrieval (RAG)
39
- llama_pipeline = pipeline("text-generation", model="huggyllama/llama-2-3b-hf", device=0) # Use smaller model
 
 
40
 
41
- # Step 4: Implement RAG with FAISS for vectorized document retrieval
42
- embeddings = OpenAIEmbeddings() # You can use other embeddings if preferred
43
  loader = CSVLoader(file_path=csv_file.name if csv_file else "default_data.csv")
44
  documents = loader.load()
45
 
46
- # Use FAISS for retrieval and document search
47
  vector_store = FAISS.from_documents(documents, embeddings)
48
  retriever = vector_store.as_retriever()
 
49
 
50
- # Step 5: OpenAI for SQL query generation based on user prompt and context
51
- openai_llm = OpenAI(temperature=0)
52
- db = SQLDatabase.from_uri('sqlite:///:memory:') # Create an SQLite database for LangChain
53
- db.raw_connection = conn # Use the in-memory connection for LangChain
54
- sql_agent = create_sql_agent(openai_llm, db, verbose=True)
55
-
56
- # Step 6: Validate and Execute the SQL Query
57
  def validate_sql(query, valid_columns):
58
  """Validates the SQL query by ensuring it references only valid columns."""
59
  for column in valid_columns:
@@ -61,36 +56,34 @@ def validate_sql(query, valid_columns):
61
  return False
62
  return True
63
 
64
- # Step 7: SQL Validation with `sqlparse`
65
def validate_sql_with_sqlparse(query):
    """Check that *query* is syntactically parseable SQL.

    Delegates to sqlparse; returns True when parsing yields at least
    one statement, False otherwise (e.g. for an empty string).
    """
    statements = sqlparse.parse(query)
    return bool(statements)
69
 
70
- # Step 8: Get user prompt, retrieve context, and generate SQL query
71
  user_prompt = st.text_input("Enter your natural language prompt:")
72
  if user_prompt:
73
  try:
74
- # Step 9: Retrieve relevant context using LLaMA RAG
75
- rag_result = llama_pipeline(user_prompt, max_length=200)
76
- st.write(f"Retrieved Context from LLaMA RAG: {rag_result}")
77
-
78
- # Step 10: Generate SQL query with OpenAI based on user prompt and retrieved context
79
- query_input = f"{user_prompt} {rag_result}"
80
- generated_sql = sql_agent.run(query_input)
81
 
 
 
82
  st.write(f"Generated SQL Query: {generated_sql}")
83
 
84
- # Step 11: Validate the SQL query before execution
85
  if not validate_sql_with_sqlparse(generated_sql):
86
  st.write("Generated SQL is not valid.")
87
  elif not validate_sql(generated_sql, valid_columns):
88
  st.write("Generated SQL references invalid columns.")
89
  else:
90
- # Step 12: Execute the SQL query
91
  result = pd.read_sql(generated_sql, conn)
92
  st.write("Query Results:")
93
  st.dataframe(result)
94
 
95
  except Exception as e:
 
96
  st.write(f"Error: {e}")
 
1
+ import os
2
  import streamlit as st
3
  import pandas as pd
4
  import sqlite3
5
  import openai
 
6
  from langchain import OpenAI
7
  from langchain_community.agent_toolkits.sql.base import create_sql_agent
8
  from langchain_community.utilities import SQLDatabase
9
  from langchain_community.document_loaders import CSVLoader
10
  from langchain_community.vectorstores import FAISS
11
  from langchain_community.embeddings import OpenAIEmbeddings
12
+ from langchain.chains import RetrievalQA
13
  import sqlparse
14
+ import logging
15
 
16
  # OpenAI API key (ensure it is securely stored)
17
  openai.api_key = os.getenv("OPENAI_API_KEY")
 
28
 
29
  # Step 2: Load CSV data into SQLite database with dynamic table name
30
  conn = sqlite3.connect(':memory:') # Use an in-memory SQLite database
 
 
31
  table_name = csv_file.name.split('.')[0] if csv_file else "default_table"
32
  data.to_sql(table_name, conn, index=False, if_exists='replace')
33
 
34
  # SQL table metadata (for validation and schema)
35
  valid_columns = list(data.columns)
36
 
37
+ # Step 3: Use a SQL Agent and setup LangChain's SQL Database connection
38
+ db = SQLDatabase.from_uri('sqlite:///:memory:')
39
+ db.raw_connection = conn # Use the in-memory connection for LangChain
40
+ sql_agent = create_sql_agent(OpenAI(temperature=0), db, verbose=True)
41
 
42
+ # Step 4: Use FAISS with RAG for context retrieval
43
+ embeddings = OpenAIEmbeddings()
44
  loader = CSVLoader(file_path=csv_file.name if csv_file else "default_data.csv")
45
  documents = loader.load()
46
 
 
47
  vector_store = FAISS.from_documents(documents, embeddings)
48
  retriever = vector_store.as_retriever()
49
+ rag_chain = RetrievalQA.from_chain_type(llm=OpenAI(temperature=0), retriever=retriever)
50
 
51
+ # Step 5: Define SQL validation helpers
 
 
 
 
 
 
52
  def validate_sql(query, valid_columns):
53
  """Validates the SQL query by ensuring it references only valid columns."""
54
  for column in valid_columns:
 
56
  return False
57
  return True
58
 
 
59
def validate_sql_with_sqlparse(query):
    """Return True if sqlparse parses *query* into one or more statements.

    NOTE(review): this only checks that parsing produced output — sqlparse
    is lenient, so this is a weak syntax gate, not full SQL validation.
    """
    return len(sqlparse.parse(query)) > 0
63
 
64
+ # Step 6: Generate SQL query based on user input and run it with LangChain SQL Agent
65
  user_prompt = st.text_input("Enter your natural language prompt:")
66
  if user_prompt:
67
  try:
68
+ # Step 7: Retrieve context using RAG
69
+ context = rag_chain.run(user_prompt)
70
+ st.write(f"Retrieved Context: {context}")
 
 
 
 
71
 
72
+ # Step 8: Generate SQL query using SQL agent
73
+ generated_sql = sql_agent.run(f"{user_prompt} {context}")
74
  st.write(f"Generated SQL Query: {generated_sql}")
75
 
76
+ # Step 9: Validate SQL query
77
  if not validate_sql_with_sqlparse(generated_sql):
78
  st.write("Generated SQL is not valid.")
79
  elif not validate_sql(generated_sql, valid_columns):
80
  st.write("Generated SQL references invalid columns.")
81
  else:
82
+ # Step 10: Execute SQL query
83
  result = pd.read_sql(generated_sql, conn)
84
  st.write("Query Results:")
85
  st.dataframe(result)
86
 
87
  except Exception as e:
88
+ logging.error(f"An error occurred: {e}")
89
  st.write(f"Error: {e}")