Ari committed on
Commit 5671d43 · verified · 1 Parent(s): 174d2a6

Update app.py

Files changed (1)
  1. app.py +36 -15
app.py CHANGED
@@ -3,12 +3,16 @@ import pandas as pd
 import json
 import os
 import plotly.express as px
-from transformers import pipeline
-from datasets import Dataset
 from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
+from datasets import Dataset
 
-# Load the LLaMA-based model with RAG
-@st.cache(allow_output_mutation=True)
+# Set paths to the default files
+DEFAULT_PROMPT_PATH = "default_prompt.json"
+DEFAULT_METADATA_PATH = "default_metadata.csv"
+DEFAULT_DATA_PATH = "default_data.csv"
+
+# Load the RAG model with LLaMA-based retrieval-augmented generation
+@st.cache_resource(allow_output_mutation=True)
 def load_rag_model():
     retriever = RagRetriever.from_pretrained("facebook/rag-token-base", index_name="custom")
     tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
@@ -17,30 +21,47 @@ def load_rag_model():
 
 model, tokenizer, retriever = load_rag_model()
 
-# Title of the app
+# Title for the app
 st.title("Interactive Insights Chatbot with LLaMA + RAG")
 
-# Step 1: Upload prompt.json file
+# Step 1: Upload prompt.json file (or use default)
 prompt_file = st.file_uploader("Upload your prompt.json file", type=["json"])
-if prompt_file:
+if prompt_file is None:
+    with open(DEFAULT_PROMPT_PATH) as f:
+        prompt_data = json.load(f)
+    st.write("Using default prompt.json file.")
+else:
     prompt_data = json.load(prompt_file)
     st.write("Prompt JSON loaded successfully!")
 
-# Step 2: Upload CSV file
+# Step 2: Upload metadata.csv file (or use default)
+metadata_file = st.file_uploader("Upload your metadata.csv file", type=["csv"])
+if metadata_file is None:
+    metadata = pd.read_csv(DEFAULT_METADATA_PATH)
+    st.write("Using default metadata.csv file.")
+else:
+    metadata = pd.read_csv(metadata_file)
+    st.write("Metadata loaded successfully!")
+    st.dataframe(metadata)
+
+# Step 3: Upload CSV data file (or use default)
 csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
-if csv_file:
+if csv_file is None:
+    data = pd.read_csv(DEFAULT_DATA_PATH)
+    st.write("Using default data.csv file.")
+else:
     data = pd.read_csv(csv_file)
     st.write("Data Preview:")
     st.dataframe(data.head())
 
-    # Convert the CSV data to a Hugging Face Dataset for retrieval
-    dataset = Dataset.from_pandas(data)
+# Convert the CSV data to a Hugging Face Dataset for retrieval
+dataset = Dataset.from_pandas(data)
 
-# Step 3: Natural language prompt input
+# Step 4: Natural language prompt input
 user_prompt = st.text_input("Enter your natural language prompt:")
 
-# Step 4: Process the user prompt and generate insights using LLaMA + RAG
-if user_prompt and csv_file:
+# Step 5: Process the prompt and generate insights using LLaMA + RAG
+if user_prompt and data is not None:
     st.write(f"Processing your prompt: '{user_prompt}'")
 
     # Tokenize the prompt for LLaMA + RAG
@@ -56,7 +77,7 @@ if user_prompt and csv_file:
 
     # Example: if the prompt asks for a plot (like "show sales over time")
     if "plot sales" in user_prompt.lower():
-        # Create a bar chart (you can customize based on the prompt)
+        # Create an interactive bar chart
        fig = px.bar(data, x='Date', y='Sales', title="Sales Over Time")
        st.plotly_chart(fig)
    else:
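
A possible wiring of the custom retriever index (a sketch, not part of this commit): the hunks above build dataset = Dataset.from_pandas(data) but never show it reaching the retriever, and RagRetriever.from_pretrained(..., index_name="custom") needs an indexed passages dataset to search. The sketch below assumes the uploaded CSV has "title" and "text" columns, embeds them with facebook/dpr-ctx_encoder-single-nq-base, and requires faiss-cpu; those choices are illustrative, not taken from app.py. Note also that allow_output_mutation belongs to the deprecated st.cache, so the sketch uses st.cache_resource without it.

import pandas as pd
import streamlit as st
import torch
from datasets import Dataset
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizerFast,
    RagRetriever,
    RagSequenceForGeneration,
    RagTokenizer,
)

@st.cache_resource  # st.cache_resource takes no allow_output_mutation flag (that belonged to st.cache)
def load_rag_model(csv_path: str = "default_data.csv"):
    # Embed each CSV row so the retriever can search it ("title"/"text" columns are assumed).
    encoder_name = "facebook/dpr-ctx_encoder-single-nq-base"
    ctx_encoder = DPRContextEncoder.from_pretrained(encoder_name)
    ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(encoder_name)

    dataset = Dataset.from_pandas(pd.read_csv(csv_path))

    def embed(batch):
        inputs = ctx_tokenizer(batch["title"], batch["text"],
                               truncation=True, padding="longest", return_tensors="pt")
        with torch.no_grad():
            embeddings = ctx_encoder(**inputs).pooler_output
        return {"embeddings": embeddings.numpy()}

    dataset = dataset.map(embed, batched=True, batch_size=16)
    dataset.add_faiss_index(column="embeddings")  # FAISS index the retriever will query

    # A custom index is handed to the retriever as an indexed dataset.
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-token-base", index_name="custom", indexed_dataset=dataset
    )
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
    model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-base", retriever=retriever)
    return model, tokenizer, retriever

Because the dataset is built inside the cached loader, a different uploaded CSV would need the cache key to change (for example by passing the file's path or a hash of its contents as csv_path); otherwise Streamlit keeps serving the first index it built.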
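
The diff is truncated right after the "# Tokenize the prompt for LLaMA + RAG" comment, so here is a rough sketch of the generation step that usually follows, assuming the loader sketched above; the hardcoded prompt stands in for the st.text_input value and is purely illustrative.

model, tokenizer, retriever = load_rag_model()

user_prompt = "plot sales over time"  # stands in for st.text_input(...)

# Encode the question; RagTokenizer delegates this to its question-encoder tokenizer.
inputs = tokenizer(user_prompt, return_tensors="pt")

# The retriever attached to the model pulls supporting rows from the indexed
# dataset while the answer is being generated.
generated_ids = model.generate(input_ids=inputs["input_ids"], num_beams=2)
answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
st.write(answer)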