Ari committed on
Commit
75829f5
1 Parent(s): f62bed1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -52
app.py CHANGED
@@ -1,40 +1,26 @@
1
  import streamlit as st
2
  import pandas as pd
3
- import json
4
- import os
5
  import plotly.express as px
6
- from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
7
- from datasets import Dataset
8
 
9
  # Set paths to the default files
10
- DEFAULT_PROMPT_PATH = "default_prompt.json"
11
  DEFAULT_METADATA_PATH = "default_metadata.csv"
12
  DEFAULT_DATA_PATH = "default_data.csv"
13
 
14
- # Load the RAG model with LLaMA-based retrieval-augmented generation
15
- @st.cache_resource(allow_output_mutation=True)
16
- def load_rag_model():
17
- retriever = RagRetriever.from_pretrained("facebook/rag-token-base", index_name="custom")
18
- tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
19
- model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-base", retriever=retriever)
20
- return model, tokenizer, retriever
21
-
22
- model, tokenizer, retriever = load_rag_model()
23
 
24
- # Title for the app
25
- st.title("Interactive Insights Chatbot with LLaMA + RAG")
26
-
27
- # Step 1: Upload prompt.json file (or use default)
28
- prompt_file = st.file_uploader("Upload your prompt.json file", type=["json"])
29
- if prompt_file is None:
30
- with open(DEFAULT_PROMPT_PATH) as f:
31
- prompt_data = json.load(f)
32
- st.write("Using default prompt.json file.")
33
- else:
34
- prompt_data = json.load(prompt_file)
35
- st.write("Prompt JSON loaded successfully!")
36
 
37
- # Step 2: Upload metadata.csv file (or use default)
38
  metadata_file = st.file_uploader("Upload your metadata.csv file", type=["csv"])
39
  if metadata_file is None:
40
  metadata = pd.read_csv(DEFAULT_METADATA_PATH)
@@ -44,7 +30,7 @@ else:
44
  st.write("Metadata loaded successfully!")
45
  st.dataframe(metadata)
46
 
47
- # Step 3: Upload CSV data file (or use default)
48
  csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
49
  if csv_file is None:
50
  data = pd.read_csv(DEFAULT_DATA_PATH)
@@ -54,31 +40,24 @@ else:
54
  st.write("Data Preview:")
55
  st.dataframe(data.head())
56
 
57
- # Convert the CSV data to a Hugging Face Dataset for retrieval
58
- dataset = Dataset.from_pandas(data)
 
59
 
60
- # Step 4: Natural language prompt input
61
  user_prompt = st.text_input("Enter your natural language prompt:")
62
 
63
- # Step 5: Process the prompt and generate insights using LLaMA + RAG
64
- if user_prompt and data is not None:
65
- st.write(f"Processing your prompt: '{user_prompt}'")
66
-
67
- # Tokenize the prompt for LLaMA + RAG
68
- inputs = tokenizer(user_prompt, return_tensors="pt")
69
-
70
- # Perform retrieval-augmented generation (RAG) by retrieving data from the dataset and generating the response
71
- generated = model.generate(input_ids=inputs['input_ids'], num_return_sequences=1, num_beams=2)
72
-
73
- # Decode the output from the LLaMA + RAG model
74
- output = tokenizer.batch_decode(generated, skip_special_tokens=True)
75
-
76
- st.write(f"Insights generated: {output[0]}")
77
-
78
- # Example: if the prompt asks for a plot (like "show sales over time")
79
- if "plot sales" in user_prompt.lower():
80
- # Create an interactive bar chart
81
- fig = px.bar(data, x='Date', y='Sales', title="Sales Over Time")
82
- st.plotly_chart(fig)
83
  else:
84
- st.write("No recognized visual request in the prompt.")
 
1
  import streamlit as st
2
  import pandas as pd
3
+ import sqlite3
 
4
  import plotly.express as px
5
+ import json
 
6
 
7
  # Set paths to the default files
8
+ DEFAULT_PROMPT_PATH = "prompt_engineering.json"
9
  DEFAULT_METADATA_PATH = "default_metadata.csv"
10
  DEFAULT_DATA_PATH = "default_data.csv"
11
 
12
# Load the prompt-engineering file: a JSON document whose "prompts" list maps
# natural-language questions to SQL queries (consumed by get_query_from_prompt).
# UTF-8 is forced so the file parses identically regardless of the platform's
# locale default encoding.
with open(DEFAULT_PROMPT_PATH, encoding="utf-8") as fh:
    prompt_data = json.load(fh)
 
 
 
 
 
 
15
 
16
# Map a natural-language prompt to one of the pre-configured SQL queries.
def get_query_from_prompt(user_prompt, prompts=None):
    """Return the SQL query whose configured question appears in *user_prompt*.

    Args:
        user_prompt: Free-form text typed by the user.
        prompts: Optional list of ``{'question': ..., 'query': ...}`` dicts.
            Defaults to the module-level ``prompt_data['prompts']`` table, so
            existing callers are unaffected; passing it explicitly decouples
            the lookup from the global for reuse and testing.

    Returns:
        The first matching ``query`` string, or ``None`` when no configured
        question is a case-insensitive substring of the prompt.
    """
    if prompts is None:
        prompts = prompt_data['prompts']
    # Lowercase once instead of on every comparison.
    needle = user_prompt.lower()
    for item in prompts:
        if item['question'].lower() in needle:
            return item['query']
    return None  # No configured question matched the prompt.
 
 
 
 
 
 
22
 
23
+ # Step 1: Upload metadata.csv file (or use default)
24
  metadata_file = st.file_uploader("Upload your metadata.csv file", type=["csv"])
25
  if metadata_file is None:
26
  metadata = pd.read_csv(DEFAULT_METADATA_PATH)
 
30
  st.write("Metadata loaded successfully!")
31
  st.dataframe(metadata)
32
 
33
+ # Step 2: Upload CSV data file (or use default)
34
  csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
35
  if csv_file is None:
36
  data = pd.read_csv(DEFAULT_DATA_PATH)
 
40
  st.write("Data Preview:")
41
  st.dataframe(data.head())
42
 
43
# Step 3: expose the uploaded CSV to SQL by copying it into an in-memory
# SQLite database under the table name `sales_data`.
conn = sqlite3.connect(':memory:')
data.to_sql('sales_data', conn, index=False, if_exists='replace')

# Step 4: collect the user's natural-language prompt.
user_prompt = st.text_input("Enter your natural language prompt:")

# Step 5: translate the prompt into one of the pre-mapped SQL queries, run it
# against the in-memory database, and display the result.  Nothing is shown
# until the user has typed something.
if user_prompt:
    sql = get_query_from_prompt(user_prompt)
    if not sql:
        st.write("Sorry, I couldn't find a matching query for your prompt.")
    else:
        result_df = pd.read_sql(sql, conn)
        st.write("Query Results:")
        st.dataframe(result_df)
        # A prompt containing "plot" additionally renders the result as a
        # bar chart (assumes the query yields 'Date' and 'Sales' columns).
        if "plot" in user_prompt.lower():
            chart = px.bar(result_df, x='Date', y='Sales', title="Sales Over Time")
            st.plotly_chart(chart)