Ari committed on
Commit
937d1f9
·
verified ·
1 Parent(s): b383793

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -47
app.py CHANGED
@@ -1,63 +1,68 @@
import streamlit as st
import pandas as pd
import sqlite3
import plotly.express as px
import json

# Paths to the bundled fallback assets used when the user uploads nothing.
DEFAULT_PROMPT_PATH = "prompt_engineering.json"
DEFAULT_METADATA_PATH = "default_metadata.csv"
DEFAULT_DATA_PATH = "default_data.csv"

# Load the prompt-to-SQL mapping once at startup.
# An explicit encoding avoids platform-dependent decoding if the JSON
# contains non-ASCII text (the default locale encoding varies by OS).
with open(DEFAULT_PROMPT_PATH, encoding="utf-8") as f:
    prompt_data = json.load(f)
16
# Map a natural-language prompt to one of the canned SQL queries.
def get_query_from_prompt(user_prompt, prompts=None):
    """Return the canned SQL query whose question appears in *user_prompt*.

    Matching is a case-insensitive substring test of each configured
    question against the user's prompt.

    Args:
        user_prompt: Free-form text entered by the user.
        prompts: Optional mapping with a ``"prompts"`` list of
            ``{"question", "query"}`` dicts. Defaults to the module-level
            ``prompt_data`` loaded at startup (backward compatible).

    Returns:
        The matching SQL query string, or ``None`` when nothing matches.
    """
    source = prompt_data if prompts is None else prompts
    lowered = user_prompt.lower()  # hoisted out of the loop
    for item in source["prompts"]:
        if item["question"].lower() in lowered:
            return item["query"]
    return None  # No configured question matched the prompt
# Step 1: Upload metadata.csv file (or fall back to the bundled default).
metadata_file = st.file_uploader("Upload your metadata.csv file", type=["csv"])
if metadata_file is None:
    metadata = pd.read_csv(DEFAULT_METADATA_PATH)
    st.write("Using default metadata.csv file.")
else:
    metadata = pd.read_csv(metadata_file)
    st.write("Metadata loaded successfully!")
    st.dataframe(metadata)

# Step 2: Upload CSV data file (or fall back to the bundled default).
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file is None:
    data = pd.read_csv(DEFAULT_DATA_PATH)
    st.write("Using default data.csv file.")
else:
    data = pd.read_csv(csv_file)
    st.write("Data Preview:")
    st.dataframe(data.head())

# Step 3: Load the CSV data into a SQLite database (SQL agent).
# NOTE(review): Streamlit re-runs this script on every interaction, so the
# in-memory database is rebuilt each time — acceptable for small files.
conn = sqlite3.connect(':memory:')  # Use an in-memory SQLite database
data.to_sql('sales_data', conn, index=False, if_exists='replace')

# Step 4: Get the user prompt and map it to a canned SQL query.
user_prompt = st.text_input("Enter your natural language prompt:")

# Step 5: Process the prompt and run the mapped SQL query.
if user_prompt:
    query = get_query_from_prompt(user_prompt)
    if query:
        result = pd.read_sql(query, conn)
        st.write("Query Results:")
        st.dataframe(result)

        # If the query involves plotting (like "plot sales"), show the chart.
        # Guard on the expected columns: px.bar raises when the query result
        # does not actually contain 'Date' and 'Sales'.
        if "plot" in user_prompt.lower():
            if {'Date', 'Sales'}.issubset(result.columns):
                fig = px.bar(result, x='Date', y='Sales', title="Sales Over Time")
                st.plotly_chart(fig)
            else:
                st.write("Cannot plot: query result has no 'Date'/'Sales' columns.")
    else:
        st.write("Sorry, I couldn't find a matching query for your prompt.")
 
import os
import sqlite3
import tempfile

import streamlit as st
import pandas as pd
import openai
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain import OpenAI
from langchain.agents import create_sql_agent
from langchain.sql_database import SQLDatabase
from langchain.chains import RetrievalQA
from langchain.document_loaders import CSVLoader
from langchain.llms import HuggingFacePipeline
from langchain.vectorstores import FAISS
from langchain.embeddings.openai import OpenAIEmbeddings

# OpenAI API key — read from the environment, never hard-coded.
# (Fix: the original called os.getenv without importing os → NameError.)
openai.api_key = os.getenv("OPENAI_API_KEY")

# Step 1: Upload CSV data file (or use default)
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file is None:
    csv_path = "default_data.csv"
    data = pd.read_csv(csv_path)  # Using default CSV
    st.write("Using default data.csv file.")
else:
    # Persist the upload to disk: CSVLoader below needs a real file path,
    # and UploadedFile.name is only the original filename, not a path.
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
    tmp.write(csv_file.getbuffer())
    tmp.close()
    csv_path = tmp.name
    data = pd.read_csv(csv_path)
    st.write("Data Preview:")
    st.dataframe(data.head())

# Step 2: Load CSV data into SQLite database (SQL agent).
# Fix: a separate `sqlite:///:memory:` SQLAlchemy engine sees its OWN empty
# in-memory database, and assigning db.raw_connection does not rewire the
# engine — so the agent would query an empty DB. Use one file-backed
# database shared by pandas and LangChain instead.
db_file = os.path.join(tempfile.gettempdir(), "sales_data.db")
conn = sqlite3.connect(db_file)
data.to_sql('sales_data', conn, index=False, if_exists='replace')
conn.commit()

# Create a SQL database connection for LangChain (same file as above).
db = SQLDatabase.from_uri(f"sqlite:///{db_file}")

# Step 3: Use LLaMA for context retrieval (RAG)
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
llama_model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")
# Fix: RetrievalQA expects a LangChain LLM, not a bare transformers model,
# so wrap the model in a text-generation pipeline.
llama_llm = HuggingFacePipeline(
    pipeline=pipeline("text-generation", model=llama_model, tokenizer=tokenizer)
)

# Load and vectorize documents for retrieval
embeddings = OpenAIEmbeddings()  # Using OpenAI embeddings, but you can swap this out for another one
loader = CSVLoader(file_path=csv_path)
documents = loader.load()

# Use FAISS to create a retriever from the documents
vector_store = FAISS.from_documents(documents, embeddings)
retriever = vector_store.as_retriever()

# Step 4: Create a RAG (Retrieval-Augmented Generation) chain
rag_chain = RetrievalQA.from_chain_type(llama_llm, retriever=retriever)

# Step 5: Use OpenAI for SQL query generation
openai_llm = OpenAI(temperature=0)  # temperature=0 for deterministic SQL
sql_agent = create_sql_agent(openai_llm, db, verbose=True)

# Step 6: Get user prompt and augment with RAG retrieval before SQL generation
user_prompt = st.text_input("Enter your natural language prompt:")
if user_prompt:
    try:
        # Step 7: Retrieve context using LLaMA-based RAG
        rag_result = rag_chain.run(user_prompt)
        st.write(f"Retrieved Context from LLaMA RAG: {rag_result}")

        # Step 8: Generate and execute SQL query using OpenAI based on prompt and retrieved context
        query_input = f"{user_prompt} {rag_result}"
        response = sql_agent.run(query_input)
        st.write(f"Generated SQL Query Results: {response}")
    except Exception as e:
        # Surface errors in the UI rather than crashing the Streamlit app.
        st.write(f"An error occurred: {e}")