DrishtiSharma commited on
Commit
3dc0491
·
verified ·
1 Parent(s): 0625cfa

Create llm_not_gen.py

Browse files
Files changed (1) hide show
  1. mylab/llm_not_gen.py +179 -0
mylab/llm_not_gen.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import plotly.express as px
4
+ from datasets import load_dataset
5
+ from pandasai import Agent
6
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
7
+ from langchain_community.vectorstores import FAISS
8
+ from langchain_openai import ChatOpenAI
9
+ from langchain.chains import RetrievalQA
10
+ from langchain.schema import Document
11
+ import os
12
+ import logging
13
+
14
# Configure module-level logging (DEBUG so data-loading failures are traceable).
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Fetch API keys from environment variables.
api_key = os.getenv("OPENAI_API_KEY")
pandasai_api_key = os.getenv("PANDASAI_API_KEY")

# Collect the names of any unset keys so the error can name them all at once.
missing_keys = [
    key_name
    for key_name, key_value in (
        ("OPENAI_API_KEY", api_key),
        ("PANDASAI_API_KEY", pandasai_api_key),
    )
    if not key_value
]

if missing_keys:
    missing_keys_str = ", ".join(missing_keys)
    raise EnvironmentError(
        f"The following API key(s) are missing: {missing_keys_str}. Please set them in the environment."
    )

# Title of the app
st.title("Data Analyzer")
37
+
38
# Function to load datasets into session
def load_dataset_into_session():
    """Render the dataset-source picker and load the selection into session state.

    Three sources are offered: a CSV shipped in the repo, a Hugging Face
    dataset, and a user-uploaded CSV. On success the DataFrame is stored in
    ``st.session_state.df`` and a 10-row preview is rendered; failures are
    shown in the UI and logged.
    """

    def _store_and_preview(frame, message):
        # Shared success path for all three sources: stash the frame,
        # confirm, and show a short preview (previously duplicated 3x).
        st.session_state.df = frame
        st.success(message)
        st.dataframe(frame.head(10))

    input_option = st.radio(
        "Select Dataset Input:",
        ["Use Repo Directory Dataset", "Use Hugging Face Dataset", "Upload CSV File"],
    )

    # Option 1: Load dataset from the repo directory
    if input_option == "Use Repo Directory Dataset":
        file_path = "./source/test.csv"
        if st.button("Load Repo Dataset"):
            try:
                _store_and_preview(
                    pd.read_csv(file_path),
                    f"File loaded successfully from '{file_path}'!",
                )
            except Exception as e:
                st.error(f"Error loading dataset from the repo directory: {e}")
                logger.error(f"Error loading dataset from repo directory: {e}")

    # Option 2: Load dataset from Hugging Face
    elif input_option == "Use Hugging Face Dataset":
        dataset_name = st.text_input(
            "Enter Hugging Face Dataset Name:", value="HUPD/hupd"
        )
        if st.button("Load Hugging Face Dataset"):
            try:
                # SECURITY: trust_remote_code=True executes code from the
                # dataset repository — only load dataset names you trust.
                dataset = load_dataset(dataset_name, split="train", trust_remote_code=True)
                # Convert the Hugging Face dataset to a Pandas DataFrame.
                if hasattr(dataset, "to_pandas"):
                    frame = dataset.to_pandas()
                else:
                    frame = pd.DataFrame(dataset)
                _store_and_preview(
                    frame,
                    f"Hugging Face Dataset '{dataset_name}' loaded successfully!",
                )
            except Exception as e:
                st.error(f"Error loading Hugging Face dataset: {e}")
                logger.error(f"Error loading Hugging Face dataset: {e}")

    # Option 3: Upload CSV File
    elif input_option == "Upload CSV File":
        uploaded_file = st.file_uploader("Upload a CSV File:", type=["csv"])
        if uploaded_file:
            try:
                _store_and_preview(
                    pd.read_csv(uploaded_file),
                    "File uploaded successfully!",
                )
            except Exception as e:
                st.error(f"Error reading uploaded file: {e}")
                logger.error(f"Error reading uploaded file: {e}")
87
+
88
# Make sure the DataFrame slot exists in session state before any widget
# callback tries to read it.
if "df" not in st.session_state:
    st.session_state["df"] = None

# Render the loader UI (stores the chosen dataset in st.session_state.df).
load_dataset_into_session()
94
+
95
# Check if a dataset is loaded; if so, build the analysis UI.
if st.session_state.df is not None:
    df = st.session_state.df
    try:
        # Initialize the PandasAI Agent for natural-language analysis.
        agent = Agent(df)

        # Convert DataFrame rows into text documents for RAG: one Document
        # per row, serialized as "col: value" pairs (null cells skipped).
        documents = [
            Document(
                page_content=", ".join(
                    f"{col}: {row[col]}" for col in df.columns if pd.notnull(row[col])
                ),
                metadata={"index": index},
            )
            for index, row in df.iterrows()
        ]

        # Set up RAG: embed every row, index with FAISS, answer via retrieval.
        # NOTE(review): this re-embeds the whole dataset on every Streamlit
        # rerun — consider caching the vector store for large frames.
        embeddings = OpenAIEmbeddings()
        vectorstore = FAISS.from_documents(documents, embeddings)
        retriever = vectorstore.as_retriever()
        qa_chain = RetrievalQA.from_chain_type(
            llm=ChatOpenAI(),
            chain_type="stuff",
            retriever=retriever,
        )

        # Create tabs
        tab1, tab2, tab3 = st.tabs(
            ["PandasAI Analysis", "RAG Q&A", "Data Visualization"]
        )

        # Tab 1: PandasAI Analysis
        with tab1:
            st.header("PandasAI Analysis")
            pandas_question = st.text_input("Ask a question about the data (PandasAI):")
            if pandas_question:
                try:
                    result = agent.chat(pandas_question)
                    st.write("PandasAI Answer:", result)
                except Exception as e:
                    st.error(f"Error during PandasAI Analysis: {e}")

        # Tab 2: RAG Q&A
        with tab2:
            st.header("RAG Q&A")
            rag_question = st.text_input("Ask a question about the data (RAG):")
            if rag_question:
                try:
                    result = qa_chain.run(rag_question)
                    st.write("RAG Answer:", result)
                except Exception as e:
                    st.error(f"Error during RAG Q&A: {e}")

        # Tab 3: Data Visualization
        with tab3:
            st.header("Data Visualization")
            viz_question = st.text_input(
                "What kind of graph would you like to create? (e.g., 'Show a scatter plot of salary vs experience')"
            )
            if viz_question:
                try:
                    result = agent.chat(viz_question)

                    # Extract a fenced ```python ...``` block from the reply.
                    # str() guards against non-string agent response objects.
                    import re

                    code_pattern = r"```python\n(.*?)\n```"
                    code_match = re.search(code_pattern, str(result), re.DOTALL)

                    if code_match:
                        viz_code = code_match.group(1)
                        # NOTE(review): naive text substitution — matplotlib
                        # and Plotly are not call-compatible, so this only
                        # rescues replies already close to Plotly syntax.
                        viz_code = viz_code.replace("plt.", "px.")
                        # SECURITY: exec() runs LLM-generated code with the
                        # app's full privileges; do not expose this app to
                        # untrusted users. Execute in an explicit namespace
                        # and fetch `fig` from it instead of relying on the
                        # code leaking a global `fig` (the original raised
                        # NameError when no `fig` was defined).
                        exec_ns = {"px": px, "pd": pd, "st": st, "df": df}
                        exec(viz_code, exec_ns)
                        fig = exec_ns.get("fig")
                        if fig is not None:
                            st.plotly_chart(fig)
                        else:
                            st.warning("Could not generate a graph. Try a different query.")
                    else:
                        st.warning("Could not generate a graph. Try a different query.")
                except Exception as e:
                    st.error(f"Error during Data Visualization: {e}")
    except Exception as e:
        st.error(f"An error occurred during processing: {e}")
else:
    st.info("Please load a dataset to start analysis.")