"""Streamlit app: upload a prompt JSON and a CSV, then query a RAG model.

Flow of the script (Streamlit re-runs it top to bottom on every interaction):

1. Load the RAG model, tokenizer and retriever once (cached).
2. Let the user upload an optional ``prompt.json`` and a CSV file.
3. Feed the user's free-text prompt to the model and show the generated text.
4. If the prompt contains the phrase "plot sales", draw a bar chart of the
   ``Sales`` column against the ``Date`` column of the uploaded CSV.
"""

import json
import os

import pandas as pd
import plotly.express as px
import streamlit as st
from datasets import Dataset
from transformers import (
    RagRetriever,
    RagSequenceForGeneration,
    RagTokenizer,
    pipeline,
)

# Checkpoint used for the tokenizer, the retriever and the generator.
# NOTE: the UI title below mentions LLaMA, but this is the only checkpoint the
# script loads.
RAG_CHECKPOINT = "facebook/rag-token-base"

# Columns the "plot sales" chart reads from the uploaded CSV.
PLOT_X_COLUMN = "Date"
PLOT_Y_COLUMN = "Sales"

# ``st.cache`` is Streamlit's deprecated legacy cache; ``st.cache_resource`` is
# its replacement for unserialisable objects such as models. Fall back to the
# legacy decorator only when running on a Streamlit that predates it.
_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is None:
    _cache_resource = st.cache(allow_output_mutation=True)


@_cache_resource
def load_rag_model():
    """Load the RAG generator, tokenizer and retriever from ``RAG_CHECKPOINT``.

    The result is cached by Streamlit so the checkpoint is loaded once rather
    than on every script re-run.

    Returns:
        A ``(model, tokenizer, retriever)`` tuple.
    """
    # NOTE(review): ``index_name="custom"`` is requested, but no
    # ``indexed_dataset`` / ``passages_path`` / ``index_path`` is supplied, and
    # the Dataset built from the uploaded CSV further down is never handed to
    # the retriever. Confirm this call can succeed as written; the CSV most
    # likely needs to be embedded, FAISS-indexed and passed in as
    # ``indexed_dataset`` for retrieval to use it.
    retriever = RagRetriever.from_pretrained(RAG_CHECKPOINT, index_name="custom")
    tokenizer = RagTokenizer.from_pretrained(RAG_CHECKPOINT)
    model = RagSequenceForGeneration.from_pretrained(
        RAG_CHECKPOINT, retriever=retriever
    )
    return model, tokenizer, retriever


# Nothing below can work without the model, so report a load failure in the
# page and halt this run instead of surfacing a raw traceback.
try:
    model, tokenizer, retriever = load_rag_model()
except (ImportError, OSError, ValueError) as exc:
    st.error(f"Could not load the RAG model '{RAG_CHECKPOINT}': {exc}")
    st.stop()

# Title of the app
st.title("Interactive Insights Chatbot with LLaMA + RAG")

# Step 1: Upload prompt.json file
prompt_file = st.file_uploader("Upload your prompt.json file", type=["json"])
if prompt_file is not None:
    try:
        # NOTE: prompt_data is parsed but not consumed anywhere in this script.
        prompt_data = json.load(prompt_file)
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        st.error(f"Could not parse the uploaded prompt JSON: {exc}")
    else:
        st.write("Prompt JSON loaded successfully!")

# Step 2: Upload CSV file
data = None
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file is not None:
    try:
        data = pd.read_csv(csv_file)
    except (
        pd.errors.EmptyDataError,
        pd.errors.ParserError,
        UnicodeDecodeError,
    ) as exc:
        st.error(f"Could not read the uploaded CSV: {exc}")
    else:
        st.write("Data Preview:")
        st.dataframe(data.head())

        # Convert the CSV data to a Hugging Face Dataset intended for retrieval.
        # NOTE(review): this dataset is not yet passed to the retriever.
        dataset = Dataset.from_pandas(data)

# Step 3: Natural language prompt input
user_prompt = st.text_input("Enter your natural language prompt:")

# Step 4: Generate a response for the prompt once a CSV has been loaded
if user_prompt and data is not None:
    st.write(f"Processing your prompt: '{user_prompt}'")

    # Tokenize the prompt for the RAG model
    inputs = tokenizer(user_prompt, return_tensors="pt")

    # Retrieval-augmented generation: the model queries its retriever
    # internally while generating.
    generated = model.generate(
        input_ids=inputs["input_ids"],
        num_return_sequences=1,
        num_beams=2,
    )

    # Decode the generated token ids back into text
    output = tokenizer.batch_decode(generated, skip_special_tokens=True)
    st.write(f"Insights generated: {output[0]}")

    # Example: a prompt containing "plot sales" (e.g. "plot sales over time")
    # triggers a chart; the match is case-insensitive.
    if "plot sales" in user_prompt.lower():
        missing_columns = [
            column
            for column in (PLOT_X_COLUMN, PLOT_Y_COLUMN)
            if column not in data.columns
        ]
        if missing_columns:
            # Without this guard, px.bar raises when the CSV has other columns.
            st.warning(
                "Cannot plot sales: the CSV is missing column(s): "
                + ", ".join(missing_columns)
            )
        else:
            # Create a bar chart (you can customize based on the prompt)
            fig = px.bar(
                data,
                x=PLOT_X_COLUMN,
                y=PLOT_Y_COLUMN,
                title="Sales Over Time",
            )
            st.plotly_chart(fig)
    else:
        st.write("No recognized visual request in the prompt.")