DrishtiSharma commited on
Commit
b193f65
·
verified ·
1 Parent(s): ae28a57

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -0
app.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from pandasai.llm.openai import OpenAI
3
+ from dotenv import load_dotenv
4
+ import os
5
+ import pandas as pd
6
+ from pandasai import PandasAI
7
+ from datasets import load_dataset
8
+ import time
9
+
10
+
11
+ openai_api_key = os.getenv("OPENAI_API_KEY")
12
+
13
+ def chat_with_csv(df, prompt):
14
+ llm = OpenAI(api_token=openai_api_key)
15
+ pandas_ai = PandasAI(llm)
16
+ result = pandas_ai.run(df, prompt=prompt)
17
+ return result
18
+
19
+ def load_huggingface_dataset(dataset_name):
20
+ progress_bar = st.progress(0)
21
+ try:
22
+ progress_bar.progress(10)
23
+ dataset = load_dataset(dataset_name, name="sample", split="train", trust_remote_code=True, uniform_split=True)
24
+ progress_bar.progress(50)
25
+ if hasattr(dataset, "to_pandas"):
26
+ df = dataset.to_pandas()
27
+ else:
28
+ df = pd.DataFrame(dataset)
29
+ progress_bar.progress(100)
30
+ return df
31
+ except Exception as e:
32
+ progress_bar.progress(0)
33
+ raise e
34
+
35
+ def load_uploaded_csv(uploaded_file):
36
+ progress_bar = st.progress(0)
37
+ try:
38
+ progress_bar.progress(10)
39
+ time.sleep(1)
40
+ progress_bar.progress(50)
41
+ df = pd.read_csv(uploaded_file)
42
+ progress_bar.progress(100)
43
+ return df
44
+ except Exception as e:
45
+ progress_bar.progress(0)
46
+ raise e
47
+
48
+ def load_dataset_into_session():
49
+ input_option = st.radio(
50
+ "Select Dataset Input:",
51
+ ["Use Repo Directory Dataset", "Use Hugging Face Dataset", "Upload CSV File"],
52
+ index=1,
53
+ horizontal=True
54
+ )
55
+
56
+ if input_option == "Use Repo Directory Dataset":
57
+ file_path = "./source/test.csv"
58
+ if st.button("Load Dataset"):
59
+ try:
60
+ with st.spinner("Loading dataset from the repo directory..."):
61
+ st.session_state.df = pd.read_csv(file_path)
62
+ st.success(f"File loaded successfully from '{file_path}'!")
63
+ except Exception as e:
64
+ st.error(f"Error loading dataset from the repo directory: {e}")
65
+
66
+ elif input_option == "Use Hugging Face Dataset":
67
+ dataset_name = st.text_input("Enter Hugging Face Dataset Name:", value="HUPD/hupd")
68
+ if st.button("Load Dataset"):
69
+ try:
70
+ st.session_state.df = load_huggingface_dataset(dataset_name)
71
+ st.success(f"Hugging Face Dataset '{dataset_name}' loaded successfully!")
72
+ except Exception as e:
73
+ st.error(f"Error loading Hugging Face dataset: {e}")
74
+
75
+ elif input_option == "Upload CSV File":
76
+ uploaded_file = st.file_uploader("Upload a CSV File:", type=["csv"])
77
+ if uploaded_file:
78
+ try:
79
+ st.session_state.df = load_uploaded_csv(uploaded_file)
80
+ st.success("File uploaded successfully!")
81
+ except Exception as e:
82
+ st.error(f"Error reading uploaded file: {e}")
83
+
84
+ # Streamlit app main
85
+ st.set_page_config(layout='wide')
86
+ st.title("ChatCSV powered by LLM")
87
+
88
+ # Ensure session state for the dataframe
89
+ if "df" not in st.session_state:
90
+ st.session_state.df = pd.DataFrame() # Initialize with an empty dataframe
91
+
92
+ st.header("Load Your Dataset")
93
+ load_dataset_into_session()
94
+
95
+ if not st.session_state.df.empty:
96
+ st.subheader("Dataset Preview")
97
+ st.dataframe(st.session_state.df, use_container_width=True)
98
+
99
+ st.subheader("Chat with Your Dataset")
100
+ user_query = st.text_area("Enter your query:")
101
+
102
+ if st.button("Run Query"):
103
+ if user_query.strip():
104
+ with st.spinner("Processing your query..."):
105
+ try:
106
+ result = chat_with_csv(st.session_state.df, user_query)
107
+ st.success(result)
108
+ except Exception as e:
109
+ st.error(f"Error processing your query: {e}")
110
+ else:
111
+ st.warning("Please enter a query before running.")