Ari committed
Commit • 75829f5 • 1 Parent(s): f62bed1
Update app.py
app.py CHANGED
@@ -1,40 +1,26 @@
 import streamlit as st
 import pandas as pd
-import
-import os
+import sqlite3
 import plotly.express as px
-
-from datasets import Dataset
+import json
 
 # Set paths to the default files
-DEFAULT_PROMPT_PATH = "
+DEFAULT_PROMPT_PATH = "prompt_engineering.json"
 DEFAULT_METADATA_PATH = "default_metadata.csv"
 DEFAULT_DATA_PATH = "default_data.csv"
 
-# Load the
-
-
-    retriever = RagRetriever.from_pretrained("facebook/rag-token-base", index_name="custom")
-    tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
-    model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-base", retriever=retriever)
-    return model, tokenizer, retriever
-
-model, tokenizer, retriever = load_rag_model()
+# Load the prompt engineering JSON file (use default if no user-uploaded prompt file)
+with open(DEFAULT_PROMPT_PATH) as f:
+    prompt_data = json.load(f)
 
-#
-
-
-
-
-if
-    with open(DEFAULT_PROMPT_PATH) as f:
-        prompt_data = json.load(f)
-    st.write("Using default prompt.json file.")
-else:
-    prompt_data = json.load(prompt_file)
-    st.write("Prompt JSON loaded successfully!")
+# Function to find a query based on the user prompt
+def get_query_from_prompt(user_prompt):
+    for item in prompt_data['prompts']:
+        if item['question'].lower() in user_prompt.lower():
+            return item['query']
+    return None  # Return None if no matching query is found
 
-# Step
+# Step 1: Upload metadata.csv file (or use default)
 metadata_file = st.file_uploader("Upload your metadata.csv file", type=["csv"])
 if metadata_file is None:
     metadata = pd.read_csv(DEFAULT_METADATA_PATH)
@@ -44,7 +30,7 @@ else:
     st.write("Metadata loaded successfully!")
     st.dataframe(metadata)
 
-# Step
+# Step 2: Upload CSV data file (or use default)
 csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
 if csv_file is None:
     data = pd.read_csv(DEFAULT_DATA_PATH)
@@ -54,31 +40,24 @@ else:
     st.write("Data Preview:")
     st.dataframe(data.head())
 
-#
-
+# Step 3: Load CSV data into a SQLite database (SQL agent)
+conn = sqlite3.connect(':memory:')  # Use an in-memory SQLite database
+data.to_sql('sales_data', conn, index=False, if_exists='replace')
 
-# Step 4:
+# Step 4: Get user prompt and map to SQL query
 user_prompt = st.text_input("Enter your natural language prompt:")
 
-# Step 5: Process the prompt and generate
-if user_prompt
-
-
-
-
-
-
-
-
-
-
-
-    st.write(f"Insights generated: {output[0]}")
-
-    # Example: if the prompt asks for a plot (like "show sales over time")
-    if "plot sales" in user_prompt.lower():
-        # Create an interactive bar chart
-        fig = px.bar(data, x='Date', y='Sales', title="Sales Over Time")
-        st.plotly_chart(fig)
+# Step 5: Process the prompt and generate SQL query dynamically
+if user_prompt:
+    query = get_query_from_prompt(user_prompt)
+    if query:
+        result = pd.read_sql(query, conn)
+        st.write("Query Results:")
+        st.dataframe(result)
+
+        # If the query involves plotting (like "plot sales"), show the chart
+        if "plot" in user_prompt.lower():
+            fig = px.bar(result, x='Date', y='Sales', title="Sales Over Time")
+            st.plotly_chart(fig)
     else:
-        st.write("
+        st.write("Sorry, I couldn't find a matching query for your prompt.")
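
The prompt_engineering.json file referenced by DEFAULT_PROMPT_PATH is not part of this commit, so its exact contents are unknown; get_query_from_prompt only assumes a top-level "prompts" list whose items carry "question" and "query" keys. The sketch below writes one plausible version of that file and exercises the same prompt-to-SQL flow outside Streamlit. The example questions, SQL statements, and sales rows are illustrative assumptions, not data from the Space.

# Sketch only: a plausible prompt_engineering.json and an offline run of the
# prompt -> SQL -> DataFrame flow from app.py. Questions, queries, and sample
# rows below are assumptions for illustration, not files from the Space.
import json
import sqlite3

import pandas as pd

# A minimal prompt file matching the keys get_query_from_prompt reads:
# a top-level "prompts" list of {"question", "query"} pairs.
sample_prompts = {
    "prompts": [
        {
            "question": "plot sales",
            "query": "SELECT Date, Sales FROM sales_data ORDER BY Date",
        },
        {
            "question": "total sales",
            "query": "SELECT SUM(Sales) AS total_sales FROM sales_data",
        },
    ]
}
with open("prompt_engineering.json", "w") as f:
    json.dump(sample_prompts, f, indent=2)

with open("prompt_engineering.json") as f:
    prompt_data = json.load(f)

def get_query_from_prompt(user_prompt):
    # Same substring matching as app.py: return the first query whose
    # question appears inside the user's prompt, else None.
    for item in prompt_data["prompts"]:
        if item["question"].lower() in user_prompt.lower():
            return item["query"]
    return None

# Stand-in for default_data.csv: the two columns the example queries expect.
data = pd.DataFrame(
    {"Date": ["2024-01-01", "2024-01-02", "2024-01-03"], "Sales": [120, 95, 143]}
)

# Mirror Step 3 of app.py: load the DataFrame into an in-memory SQLite table.
conn = sqlite3.connect(":memory:")
data.to_sql("sales_data", conn, index=False, if_exists="replace")

query = get_query_from_prompt("Please plot sales over time")
if query:
    result = pd.read_sql(query, conn)
    print(result)  # the rows app.py would show with st.dataframe
else:
    print("No matching query for that prompt.")

Because matching is plain substring containment, the questions in the JSON work best as short key phrases ("plot sales", "total sales"); a prompt that rewords the phrase will fall through to the "Sorry, I couldn't find a matching query" branch.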