# Streamlit app: Interactive Insights Chatbot (LLaMA + RAG) — Hugging Face Space
import json
import os

import pandas as pd
import plotly.express as px
import streamlit as st
from datasets import Dataset
from transformers import (
    RagRetriever,
    RagSequenceForGeneration,
    RagTokenizer,
    pipeline,
)
# Load the RAG model components. Cached so the (very heavy) download/instantiation
# happens once per server process instead of on every Streamlit rerun.
@st.cache_resource
def load_rag_model():
    """Load and cache the RAG tokenizer, retriever, and generator.

    Returns:
        tuple: ``(model, tokenizer, retriever)`` built from the
        ``facebook/rag-token-base`` checkpoint.
    """
    # NOTE(review): index_name="custom" normally requires passages_path /
    # index_path to be supplied as well — confirm a custom index is available,
    # otherwise retrieval will fail at load time.
    retriever = RagRetriever.from_pretrained("facebook/rag-token-base", index_name="custom")
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
    model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-base", retriever=retriever)
    return model, tokenizer, retriever
model, tokenizer, retriever = load_rag_model()

# Title of the app
st.title("Interactive Insights Chatbot with LLaMA + RAG")

# Step 1: Upload prompt.json file
prompt_file = st.file_uploader("Upload your prompt.json file", type=["json"])
if prompt_file:
    try:
        # UploadedFile is file-like, so json.load can read it directly.
        prompt_data = json.load(prompt_file)
        st.write("Prompt JSON loaded successfully!")
    except json.JSONDecodeError:
        # Malformed uploads previously crashed the app with a raw traceback;
        # show a friendly error instead.
        prompt_data = None
        st.error("The uploaded file is not valid JSON.")
# Step 2: Upload CSV file
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file:
    data = pd.read_csv(csv_file)
    st.write("Data Preview:")
    st.dataframe(data.head())
    # Convert the CSV data to a Hugging Face Dataset for retrieval.
    # NOTE(review): `dataset` is never passed to the retriever below —
    # confirm whether the custom index is meant to be built from it.
    dataset = Dataset.from_pandas(data)

# Step 3: Natural language prompt input
user_prompt = st.text_input("Enter your natural language prompt:")
# Step 4: Process the user prompt and generate insights using LLaMA + RAG
if user_prompt and csv_file:
    st.write(f"Processing your prompt: '{user_prompt}'")

    # Tokenize the prompt for the RAG question encoder.
    inputs = tokenizer(user_prompt, return_tensors="pt")

    # Retrieval-augmented generation: the retriever fetches supporting
    # passages and the seq2seq model generates the answer.
    generated = model.generate(
        input_ids=inputs["input_ids"],
        num_return_sequences=1,
        num_beams=2,
    )

    # Decode the generated token ids back into text.
    output = tokenizer.batch_decode(generated, skip_special_tokens=True)
    st.write(f"Insights generated: {output[0]}")

    # Example: if the prompt asks for a plot (like "plot sales")
    if "plot sales" in user_prompt.lower():
        # Guard against CSVs that lack the hard-coded columns; px.bar would
        # otherwise raise and crash the app.
        if {"Date", "Sales"}.issubset(data.columns):
            fig = px.bar(data, x='Date', y='Sales', title="Sales Over Time")
            st.plotly_chart(fig)
        else:
            st.write("The CSV needs 'Date' and 'Sales' columns for this plot.")
    else:
        st.write("No recognized visual request in the prompt.")