Ari committed · Commit 3eb59a4 · verified · 1 Parent(s): ad6c7c2

Create app.py

Files changed (1):
  app.py +63 -0
app.py ADDED
import streamlit as st
import pandas as pd
import json
import plotly.express as px
from datasets import Dataset
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration

# Load the RAG model once and reuse it across reruns.
# st.cache_resource replaces the deprecated st.cache(allow_output_mutation=True).
# Note: facebook/rag-token-base is a BART-based RAG model, not LLaMA,
# despite the app's title.
@st.cache_resource
def load_rag_model():
    # index_name="custom" requires a pre-built indexed dataset and fails
    # without one; the dummy "exact" index keeps the app runnable out of
    # the box (see the sketch below for wiring in the uploaded CSV).
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True
    )
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
    model = RagSequenceForGeneration.from_pretrained(
        "facebook/rag-token-base", retriever=retriever
    )
    return model, tokenizer, retriever

model, tokenizer, retriever = load_rag_model()

# Title of the app
st.title("Interactive Insights Chatbot with LLaMA + RAG")

# Step 1: Upload prompt.json file
prompt_file = st.file_uploader("Upload your prompt.json file", type=["json"])
if prompt_file:
    prompt_data = json.load(prompt_file)
    st.write("Prompt JSON loaded successfully!")

# Step 2: Upload CSV file
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
data = None
if csv_file:
    data = pd.read_csv(csv_file)
    st.write("Data Preview:")
    st.dataframe(data.head())

    # Convert the CSV data to a Hugging Face Dataset. Note: the retriever
    # above still queries its own pretrained index; plugging this dataset in
    # requires embedding and FAISS-indexing it first (see the sketch below).
    dataset = Dataset.from_pandas(data)

# Step 3: Natural language prompt input
user_prompt = st.text_input("Enter your natural language prompt:")

# Step 4: Process the user prompt and generate insights using RAG
if user_prompt and csv_file:
    st.write(f"Processing your prompt: '{user_prompt}'")

    # Tokenize the prompt for the RAG question encoder
    inputs = tokenizer(user_prompt, return_tensors="pt")

    # Retrieval-augmented generation: the model retrieves passages from its
    # index and generates a response conditioned on them
    generated = model.generate(
        input_ids=inputs["input_ids"], num_return_sequences=1, num_beams=2
    )

    # Decode the output from the RAG model
    output = tokenizer.batch_decode(generated, skip_special_tokens=True)
    st.write(f"Insights generated: {output[0]}")

    # Example: if the prompt asks for a plot (like "show sales over time")
    if "plot sales" in user_prompt.lower():
        # The chart assumes the CSV has 'Date' and 'Sales' columns; guard so
        # the app doesn't crash on other schemas.
        if {"Date", "Sales"}.issubset(data.columns):
            fig = px.bar(data, x="Date", y="Sales", title="Sales Over Time")
            st.plotly_chart(fig)
        else:
            st.write("This plot needs 'Date' and 'Sales' columns in the CSV.")
    else:
        st.write("No recognized visual request in the prompt.")