Ari committed on
Commit
6a2a63a
·
verified ·
1 Parent(s): 6084d8d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -21
app.py CHANGED
@@ -10,59 +10,90 @@ from langchain.chains import RetrievalQA
10
  from langchain.document_loaders import CSVLoader
11
  from langchain.vectorstores import FAISS
12
  from langchain.embeddings.openai import OpenAIEmbeddings
 
13
 
14
- # OpenAI API key (ensure it's stored securely)
15
  openai.api_key = os.getenv("OPENAI_API_KEY")
16
 
17
  # Step 1: Upload CSV data file (or use default)
18
  csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
19
  if csv_file is None:
20
- data = pd.read_csv("default_data.csv") # Using default CSV
21
  st.write("Using default data.csv file.")
22
  else:
23
  data = pd.read_csv(csv_file)
24
- st.write("Data Preview:")
25
  st.dataframe(data.head())
26
 
27
- # Step 2: Load CSV data into SQLite database (SQL agent)
28
- conn = sqlite3.connect(':memory:') # In-memory SQLite database
29
- data.to_sql('sales_data', conn, index=False, if_exists='replace')
30
 
31
- # Create a SQL database connection for LangChain
32
- db = SQLDatabase.from_uri('sqlite:///:memory:')
33
- db.raw_connection = conn
 
 
 
34
 
35
  # Step 3: Use LLaMA for context retrieval (RAG)
36
  tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
37
  llama_model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")
38
 
39
- # Load and vectorize documents for retrieval
40
- embeddings = OpenAIEmbeddings() # Using OpenAI embeddings, but you can swap this out for another one
41
  loader = CSVLoader(file_path=csv_file.name if csv_file else "default_data.csv")
42
  documents = loader.load()
43
 
44
- # Use FAISS to create a retriever from the documents
45
  vector_store = FAISS.from_documents(documents, embeddings)
46
  retriever = vector_store.as_retriever()
47
 
48
- # Step 4: Create a RAG (Retrieval-Augmented Generation) chain
49
  rag_chain = RetrievalQA.from_chain_type(llama_model, retriever=retriever)
50
 
51
- # Step 5: Use OpenAI for SQL query generation
52
- openai_llm = OpenAI(temperature=0) # OpenAI LLM for SQL query generation
 
 
53
  sql_agent = create_sql_agent(openai_llm, db, verbose=True)
54
 
55
- # Step 6: Get user prompt and augment with RAG retrieval before SQL generation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  user_prompt = st.text_input("Enter your natural language prompt:")
57
  if user_prompt:
58
  try:
59
- # Step 7: Retrieve context using LLaMA-based RAG
60
  rag_result = rag_chain.run(user_prompt)
61
  st.write(f"Retrieved Context from LLaMA RAG: {rag_result}")
62
 
63
- # Step 8: Generate and execute SQL query using OpenAI based on prompt and retrieved context
64
  query_input = f"{user_prompt} {rag_result}"
65
- response = sql_agent.run(query_input)
66
- st.write(f"Generated SQL Query Results: {response}")
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  except Exception as e:
68
- st.write(f"An error occurred: {e}")
 
10
  from langchain.document_loaders import CSVLoader
11
  from langchain.vectorstores import FAISS
12
  from langchain.embeddings.openai import OpenAIEmbeddings
13
+ import sqlparse
14
 
15
+ # OpenAI API key (ensure it is securely stored)
16
  openai.api_key = os.getenv("OPENAI_API_KEY")
17
 
18
  # Step 1: Upload CSV data file (or use default)
19
  csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
20
  if csv_file is None:
21
+ data = pd.read_csv("default_data.csv") # Use default CSV if no file is uploaded
22
  st.write("Using default data.csv file.")
23
  else:
24
  data = pd.read_csv(csv_file)
25
+ st.write(f"Data Preview ({csv_file.name}):")
26
  st.dataframe(data.head())
27
 
28
+ # Step 2: Load CSV data into SQLite database with dynamic table name
29
+ conn = sqlite3.connect(':memory:') # Use an in-memory SQLite database
 
30
 
31
+ # Dynamically name the table based on the uploaded file name or fallback to a default name
32
+ table_name = csv_file.name.split('.')[0] if csv_file else "default_table"
33
+ data.to_sql(table_name, conn, index=False, if_exists='replace')
34
+
35
+ # SQL table metadata (for validation and schema)
36
+ valid_columns = list(data.columns)
37
 
38
  # Step 3: Use LLaMA for context retrieval (RAG)
39
  tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
40
  llama_model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")
41
 
42
+ # Step 4: Implement RAG with FAISS for vectorized document retrieval
43
+ embeddings = OpenAIEmbeddings() # You can use other embeddings if preferred
44
  loader = CSVLoader(file_path=csv_file.name if csv_file else "default_data.csv")
45
  documents = loader.load()
46
 
47
+ # Use FAISS for retrieval and document search
48
  vector_store = FAISS.from_documents(documents, embeddings)
49
  retriever = vector_store.as_retriever()
50
 
 
51
  rag_chain = RetrievalQA.from_chain_type(llama_model, retriever=retriever)
52
 
53
+ # Step 5: OpenAI for SQL query generation based on user prompt and context
54
+ openai_llm = OpenAI(temperature=0)
55
+ db = SQLDatabase.from_uri('sqlite:///:memory:') # Create an SQLite database for LangChain
56
+ db.raw_connection = conn # Use the in-memory connection for LangChain
57
  sql_agent = create_sql_agent(openai_llm, db, verbose=True)
58
 
59
+ # Step 6: Validate and Execute the SQL Query
60
def validate_sql(query, valid_columns):
    """Check that the generated SQL references at least one known column.

    The previous implementation returned False unless *every* valid column
    appeared in the query, which rejected almost any legitimate query (queries
    normally reference only a subset of the table's columns). We now accept a
    query that mentions at least one known column.

    NOTE(review): this is a plain substring check — it cannot detect invalid
    identifiers or guarantee the query references *only* valid columns; a
    real check would parse the statement and compare identifiers.

    Args:
        query: The SQL string produced by the agent.
        valid_columns: Iterable of column names present in the loaded table.

    Returns:
        True if at least one valid column name occurs in the query text.
    """
    return any(column in query for column in valid_columns)
66
+
67
+ # Step 7: SQL Validation with `sqlparse`
68
def validate_sql_with_sqlparse(query):
    """Validate SQL syntax using sqlparse.

    The previous check only tested ``len(sqlparse.parse(query)) > 0``, which
    is true for virtually any non-empty string (sqlparse tokenizes arbitrary
    prose without error), so it never rejected anything. We additionally
    require the first parsed statement to have a recognized statement type
    (SELECT/INSERT/UPDATE/...); free-form prose parses as type 'UNKNOWN'.

    Args:
        query: The SQL string to validate.

    Returns:
        True if the string parses into at least one statement of a
        recognized SQL statement type.
    """
    parsed = sqlparse.parse(query)
    if not parsed:
        # Empty/whitespace-only input yields no statements.
        return False
    # get_type() returns 'UNKNOWN' when the statement does not begin with a
    # known DML/DDL keyword — reject those.
    return parsed[0].get_type() != 'UNKNOWN'
72
+
73
+ # Step 8: Get user prompt, retrieve context, and generate SQL query
74
user_prompt = st.text_input("Enter your natural language prompt:")
if user_prompt:
    try:
        # First pull supporting context out of the vector store via the
        # LLaMA-backed RAG chain, then hand prompt + context to the SQL agent.
        rag_result = rag_chain.run(user_prompt)
        st.write(f"Retrieved Context from LLaMA RAG: {rag_result}")

        generated_sql = sql_agent.run(f"{user_prompt} {rag_result}")
        st.write(f"Generated SQL Query: {generated_sql}")

        # Run the query only after it passes both the syntax check and the
        # column-reference check; otherwise report which check failed.
        if validate_sql_with_sqlparse(generated_sql):
            if validate_sql(generated_sql, valid_columns):
                result = pd.read_sql(generated_sql, conn)
                st.write("Query Results:")
                st.dataframe(result)
            else:
                st.write("Generated SQL references invalid columns.")
        else:
            st.write("Generated SQL is not valid.")
    except Exception as e:
        # Top-level boundary for the Streamlit page: surface the error to
        # the user rather than crashing the app.
        st.write(f"Error: {e}")