DrishtiSharma committed on
Commit
d7bf121
·
verified ·
1 Parent(s): 9e2b4b7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -0
app.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Streamlit app for exploring a tabular dataset.

The page offers three tabs once a dataset is loaded (from the Hugging Face
hub or an uploaded CSV):

* free-form questions answered by a PandasAI ``Agent``;
* retrieval-augmented Q&A over the rows (FAISS + OpenAI embeddings);
* chart requests, where a fenced Python snippet found in the agent's reply is
  executed and the resulting ``fig`` is rendered with Plotly.

The API keys are read from the ``OPENAI_API_KEY`` and ``PANDASAI_API_KEY``
environment variables.
"""

import hashlib
import os
import re

import pandas as pd
import plotly.express as px
import streamlit as st
from datasets import load_dataset
from langchain.chains import RetrievalQA
from langchain.schema import Document
from langchain_community.embeddings.openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI
from pandasai import Agent

# Matches the first fenced ```python block in a chat reply.
CODE_FENCE_PATTERN = re.compile(r'```python\n(.*?)\n```', re.DOTALL)


def _row_documents(df: pd.DataFrame) -> list:
    """Return one ``Document`` per row of *df*, formatted as "col: value, ..."."""
    columns = list(df.columns)
    return [
        Document(
            page_content=", ".join(f"{col}: {row[col]}" for col in columns),
            metadata={"index": index},
        )
        for index, row in df.iterrows()
    ]


def _fingerprint(documents: list) -> str:
    """Return a SHA-256 hex digest over the page contents of *documents*."""
    digest = hashlib.sha256()
    for doc in documents:
        digest.update(doc.page_content.encode("utf-8", "replace"))
        # Separator so that ("ab", "c") and ("a", "bc") hash differently.
        digest.update(b"\x00")
    return digest.hexdigest()


def _get_qa_chain(df: pd.DataFrame):
    """Return a ``RetrievalQA`` chain over the rows of *df*, or ``None`` if *df* has no rows.

    Building the FAISS index embeds every row, so the chain is kept in
    ``st.session_state`` together with a fingerprint of the row texts and is
    only rebuilt when that fingerprint changes.
    """
    documents = _row_documents(df)
    if not documents:
        return None

    fingerprint = _fingerprint(documents)
    cached_chain = st.session_state.get("qa_chain")
    if cached_chain is not None and st.session_state.get("qa_fingerprint") == fingerprint:
        return cached_chain

    vectorstore = FAISS.from_documents(documents, OpenAIEmbeddings())
    qa_chain = RetrievalQA.from_chain_type(
        llm=ChatOpenAI(),
        chain_type="stuff",
        retriever=vectorstore.as_retriever(),
    )
    st.session_state.qa_chain = qa_chain
    st.session_state.qa_fingerprint = fingerprint
    return qa_chain


def _figure_from_reply(reply, df: pd.DataFrame):
    """Run the fenced Python snippet found in *reply* and return the ``fig`` it defines.

    Returns ``None`` when *reply* contains no fenced ``python`` block or when
    the executed snippet does not leave a ``fig`` variable behind. Exceptions
    raised by the snippet propagate to the caller.
    """
    # The reply is coerced to text first; re.search on a non-string raises TypeError.
    match = CODE_FENCE_PATTERN.search(str(reply))
    if match is None:
        return None

    viz_code = match.group(1)
    # Order matters: 'plt.show()' must be rewritten before the blanket
    # 'plt.' -> 'px.' substitution, otherwise it can never match.
    # NOTE(review): the substituted line relies on the snippet defining `x` and `y`.
    viz_code = viz_code.replace('plt.show()', 'fig = px.scatter(df, x=x, y=y)')
    viz_code = viz_code.replace('plt.', 'px.')

    # SECURITY: this executes model-generated code inside the app process.
    # It is confined to its own namespace here, but it is NOT sandboxed.
    namespace = {"px": px, "pd": pd, "st": st, "df": df}
    exec(viz_code, namespace)
    return namespace.get("fig")


# Title
st.title("PandasAI Data Analysis Tool with RAG")

# Fetch API keys from environment variables
api_key = os.getenv("OPENAI_API_KEY")
pandasai_api_key = os.getenv("PANDASAI_API_KEY")

# Report missing keys before anything tries to use them; the features that
# depend on a missing key are skipped further down instead of raising.
if not api_key:
    st.error("Missing OpenAI API Key in environment variables.")
if not pandasai_api_key:
    st.error("Missing PandasAI API Key in environment variables.")

# Dataset selection
st.sidebar.header("Dataset Input Options")
input_option = st.sidebar.radio("Select Dataset Input:", ["Use Hugging Face Dataset", "Upload CSV File"])

# Initialize session state for the dataframe
if "df" not in st.session_state:
    st.session_state.df = None

# Dataset loading logic
if input_option == "Use Hugging Face Dataset":
    dataset_name = st.sidebar.text_input("Enter Hugging Face Dataset Name:", value="HUPD/hupd")
    if st.sidebar.button("Load Dataset"):
        try:
            # NOTE(review): the "sample" configuration is hard-coded, so datasets
            # without a configuration of that name will fail to load.
            # SECURITY: trust_remote_code=True lets the named dataset repository
            # run its loading script on this machine.
            dataset = load_dataset(dataset_name, name="sample", split="train", trust_remote_code=True)
            st.session_state.df = pd.DataFrame(dataset)
            st.sidebar.success(f"Dataset '{dataset_name}' loaded successfully!")
        except Exception as e:
            st.sidebar.error(f"Error loading dataset: {e}")
elif input_option == "Upload CSV File":
    uploaded_file = st.sidebar.file_uploader("Upload CSV File:", type=["csv"])
    if uploaded_file:
        try:
            st.session_state.df = pd.read_csv(uploaded_file)
            st.sidebar.success("File uploaded successfully!")
        except Exception as e:
            st.sidebar.error(f"Error loading file: {e}")

df = st.session_state.df

if df is None:
    st.warning("No dataset loaded. Please select a dataset input option from the sidebar.")
else:
    # Show the loaded dataframe preview
    st.subheader("Dataset Preview")
    st.dataframe(df.head(10))

    # Set up PandasAI Agent (only when its key is available; see the error above)
    agent = Agent(df) if pandasai_api_key else None

    # Create tabs for different functionality
    tab1, tab2, tab3 = st.tabs(["PandasAI Analysis", "RAG Q&A", "Data Visualization"])

    with tab1:
        st.header("Data Analysis with PandasAI")
        pandas_question = st.text_input("Ask a question about your data (PandasAI):")
        if pandas_question and agent is not None:
            result = agent.chat(pandas_question)
            st.write("PandasAI Answer:", result)

    with tab2:
        st.header("Q&A with RAG")
        rag_question = st.text_input("Ask a question about your data (RAG):")
        if rag_question and api_key:
            # Built lazily so rows are embedded only when a RAG question is asked.
            qa_chain = _get_qa_chain(df)
            if qa_chain is None:
                st.warning("The dataset has no rows to index for RAG.")
            else:
                result = qa_chain.run(rag_question)
                st.write("RAG Answer:", result)

    with tab3:
        st.header("Data Visualization")
        viz_question = st.text_input("What kind of graph would you like to see? (e.g., 'Show a scatter plot of salary vs experience')")
        if viz_question and agent is not None:
            try:
                result = agent.chat(viz_question)
                fig = _figure_from_reply(result, df)
                if fig is not None:
                    st.plotly_chart(fig)
                else:
                    st.write("Failed to generate a graph. Please try asking differently.")
            except Exception as e:
                st.write(f"An error occurred: {str(e)}")
                st.write("Please try rephrasing your question.")