# GenBIChatbot / app.py
# Author: Ari
# "Create app.py" — commit 3eb59a4 (verified), 2.41 kB
# (Header reconstructed from Hugging Face file-page metadata that was
# scraped into the source as bare text.)
import streamlit as st
import pandas as pd
import json
import os
import plotly.express as px
from transformers import pipeline
from datasets import Dataset
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
# Load the RAG model components once and reuse them across Streamlit reruns.
# NOTE: st.cache(allow_output_mutation=True) is deprecated and removed in
# current Streamlit; st.cache_resource is the supported replacement for
# caching unserializable objects such as models and tokenizers.
@st.cache_resource
def load_rag_model():
    """Load and cache the RAG tokenizer, retriever, and generator.

    Returns:
        tuple: (model, tokenizer, retriever) built from the
        "facebook/rag-token-base" checkpoint.
    """
    # NOTE(review): index_name="custom" normally requires passages_path and
    # index_path arguments pointing at a prebuilt FAISS index; as written,
    # loading is likely to fail at runtime — confirm the index configuration.
    retriever = RagRetriever.from_pretrained("facebook/rag-token-base", index_name="custom")
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
    model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-base", retriever=retriever)
    return model, tokenizer, retriever


model, tokenizer, retriever = load_rag_model()
# Title of the app
st.title("Interactive Insights Chatbot with LLaMA + RAG")

# Step 1: Upload prompt.json file.
# Indentation restored: the load/confirm statements belong inside the
# `if prompt_file:` body (the scraped source had them at top level,
# which is invalid Python and would run them even with no upload).
prompt_file = st.file_uploader("Upload your prompt.json file", type=["json"])
if prompt_file:
    prompt_data = json.load(prompt_file)
    st.write("Prompt JSON loaded successfully!")

# Step 2: Upload CSV file and preview it.
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file:
    data = pd.read_csv(csv_file)
    st.write("Data Preview:")
    st.dataframe(data.head())
    # Convert the CSV data to a Hugging Face Dataset for retrieval
    dataset = Dataset.from_pandas(data)
# Step 3: Natural language prompt input
user_prompt = st.text_input("Enter your natural language prompt:")

# Step 4: Process the user prompt and generate insights using LLaMA + RAG.
# `data` is only defined when a CSV was uploaded, so the `csv_file` guard
# below is what keeps the plotting branch from raising NameError.
if user_prompt and csv_file:
    st.write(f"Processing your prompt: '{user_prompt}'")
    # Tokenize the prompt for the RAG model
    inputs = tokenizer(user_prompt, return_tensors="pt")
    # Retrieval-augmented generation: retrieve supporting passages and
    # generate a response (beam search, single returned sequence).
    generated = model.generate(input_ids=inputs['input_ids'], num_return_sequences=1, num_beams=2)
    # Decode the generated token ids back to text
    output = tokenizer.batch_decode(generated, skip_special_tokens=True)
    st.write(f"Insights generated: {output[0]}")

    # The literal trigger phrase is "plot sales" (case-insensitive).
    if "plot sales" in user_prompt.lower():
        # Guard against CSVs that lack the expected columns instead of
        # crashing with a KeyError inside plotly.
        if {"Date", "Sales"}.issubset(data.columns):
            fig = px.bar(data, x='Date', y='Sales', title="Sales Over Time")
            st.plotly_chart(fig)
        else:
            st.write("CSV must contain 'Date' and 'Sales' columns to plot sales.")
    else:
        st.write("No recognized visual request in the prompt.")