Spaces:
Sleeping
Sleeping
Ari
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -3,12 +3,16 @@ import pandas as pd
|
|
3 |
import json
|
4 |
import os
|
5 |
import plotly.express as px
|
6 |
-
from transformers import pipeline
|
7 |
-
from datasets import Dataset
|
8 |
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
|
|
|
9 |
|
10 |
-
#
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
12 |
def load_rag_model():
|
13 |
retriever = RagRetriever.from_pretrained("facebook/rag-token-base", index_name="custom")
|
14 |
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
|
@@ -17,30 +21,47 @@ def load_rag_model():
|
|
17 |
|
18 |
model, tokenizer, retriever = load_rag_model()
|
19 |
|
20 |
-
# Title
|
21 |
st.title("Interactive Insights Chatbot with LLaMA + RAG")
|
22 |
|
23 |
-
# Step 1: Upload prompt.json file
|
24 |
prompt_file = st.file_uploader("Upload your prompt.json file", type=["json"])
|
25 |
-
if prompt_file:
|
|
|
|
|
|
|
|
|
26 |
prompt_data = json.load(prompt_file)
|
27 |
st.write("Prompt JSON loaded successfully!")
|
28 |
|
29 |
-
# Step 2: Upload
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
|
31 |
-
if csv_file:
|
|
|
|
|
|
|
32 |
data = pd.read_csv(csv_file)
|
33 |
st.write("Data Preview:")
|
34 |
st.dataframe(data.head())
|
35 |
|
36 |
-
|
37 |
-
|
38 |
|
39 |
-
# Step
|
40 |
user_prompt = st.text_input("Enter your natural language prompt:")
|
41 |
|
42 |
-
# Step
|
43 |
-
if user_prompt and
|
44 |
st.write(f"Processing your prompt: '{user_prompt}'")
|
45 |
|
46 |
# Tokenize the prompt for LLaMA + RAG
|
@@ -56,7 +77,7 @@ if user_prompt and csv_file:
|
|
56 |
|
57 |
# Example: if the prompt asks for a plot (like "show sales over time")
|
58 |
if "plot sales" in user_prompt.lower():
|
59 |
-
# Create
|
60 |
fig = px.bar(data, x='Date', y='Sales', title="Sales Over Time")
|
61 |
st.plotly_chart(fig)
|
62 |
else:
|
|
|
3 |
import json
|
4 |
import os
|
5 |
import plotly.express as px
|
|
|
|
|
6 |
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
|
7 |
+
from datasets import Dataset
|
8 |
|
9 |
+
# Set paths to the default files
|
10 |
+
DEFAULT_PROMPT_PATH = "default_prompt.json"
|
11 |
+
DEFAULT_METADATA_PATH = "default_metadata.csv"
|
12 |
+
DEFAULT_DATA_PATH = "default_data.csv"
|
13 |
+
|
14 |
+
# Load the RAG model with LLaMA-based retrieval-augmented generation
|
15 |
+
@st.cache_resource(allow_output_mutation=True)
|
16 |
def load_rag_model():
|
17 |
retriever = RagRetriever.from_pretrained("facebook/rag-token-base", index_name="custom")
|
18 |
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
|
|
|
21 |
|
22 |
model, tokenizer, retriever = load_rag_model()
|
23 |
|
24 |
+
# Title for the app
|
25 |
st.title("Interactive Insights Chatbot with LLaMA + RAG")
|
26 |
|
27 |
+
# Step 1: Upload prompt.json file (or use default)
|
28 |
prompt_file = st.file_uploader("Upload your prompt.json file", type=["json"])
|
29 |
+
if prompt_file is None:
|
30 |
+
with open(DEFAULT_PROMPT_PATH) as f:
|
31 |
+
prompt_data = json.load(f)
|
32 |
+
st.write("Using default prompt.json file.")
|
33 |
+
else:
|
34 |
prompt_data = json.load(prompt_file)
|
35 |
st.write("Prompt JSON loaded successfully!")
|
36 |
|
37 |
+
# Step 2: Upload metadata.csv file (or use default)
|
38 |
+
metadata_file = st.file_uploader("Upload your metadata.csv file", type=["csv"])
|
39 |
+
if metadata_file is None:
|
40 |
+
metadata = pd.read_csv(DEFAULT_METADATA_PATH)
|
41 |
+
st.write("Using default metadata.csv file.")
|
42 |
+
else:
|
43 |
+
metadata = pd.read_csv(metadata_file)
|
44 |
+
st.write("Metadata loaded successfully!")
|
45 |
+
st.dataframe(metadata)
|
46 |
+
|
47 |
+
# Step 3: Upload CSV data file (or use default)
|
48 |
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
|
49 |
+
if csv_file is None:
|
50 |
+
data = pd.read_csv(DEFAULT_DATA_PATH)
|
51 |
+
st.write("Using default data.csv file.")
|
52 |
+
else:
|
53 |
data = pd.read_csv(csv_file)
|
54 |
st.write("Data Preview:")
|
55 |
st.dataframe(data.head())
|
56 |
|
57 |
+
# Convert the CSV data to a Hugging Face Dataset for retrieval
|
58 |
+
dataset = Dataset.from_pandas(data)
|
59 |
|
60 |
+
# Step 4: Natural language prompt input
|
61 |
user_prompt = st.text_input("Enter your natural language prompt:")
|
62 |
|
63 |
+
# Step 5: Process the prompt and generate insights using LLaMA + RAG
|
64 |
+
if user_prompt and data is not None:
|
65 |
st.write(f"Processing your prompt: '{user_prompt}'")
|
66 |
|
67 |
# Tokenize the prompt for LLaMA + RAG
|
|
|
77 |
|
78 |
# Example: if the prompt asks for a plot (like "show sales over time")
|
79 |
if "plot sales" in user_prompt.lower():
|
80 |
+
# Create an interactive bar chart
|
81 |
fig = px.bar(data, x='Date', y='Sales', title="Sales Over Time")
|
82 |
st.plotly_chart(fig)
|
83 |
else:
|