Redmind committed on
Commit
a298c6f
·
verified ·
1 Parent(s): 0befa89

Upload chat.py

Browse files
Files changed (1) hide show
  1. chat.py +285 -0
chat.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from langchain.agents import initialize_agent, Tool
4
+ from langchain_community.vectorstores import FAISS
5
+ from langchain_openai import OpenAIEmbeddings
6
+ from langchain_openai import ChatOpenAI
7
+ from langchain_core.prompts import PromptTemplate
8
+ from langchain_core.output_parsers import StrOutputParser
9
+ import pandas as pd
10
+ from pandasai.llm.openai import OpenAI
11
+ from pandasai import SmartDataframe
12
+
13
# Module-level DataFrame shared between ChatHandler.answer_question (which
# populates it from a parsed markdown table) and create_visualization_csv
# (which charts it). Starts empty until a table-bearing answer is produced.
global_df = pd.DataFrame()
15
+
16
+
17
class ChatHandler:
    """Answer user questions with an LLM agent backed by a FAISS vector store
    and a SQL database, optionally producing a chart for table answers.
    """

    def __init__(self, vector_db_path, open_api_key, grok_api_key, db_final):
        """Store connection settings and build the shared LLM clients.

        Args:
            vector_db_path: Root directory containing per-document FAISS indexes.
            open_api_key: OpenAI API key used for embeddings and chat models.
            grok_api_key: Grok API key (stored for later use; unused here).
            db_final: SQLDatabase-like object exposing get_table_info() and run().
        """
        self.vector_db_path = vector_db_path
        self.openai_embeddings = OpenAIEmbeddings(api_key=open_api_key)
        self.llm_openai = ChatOpenAI(model_name="gpt-4", api_key=open_api_key, max_tokens=500, temperature=0.2)
        self.grok_api_key = grok_api_key
        self.openai_api_key = open_api_key
        self.sql_db = db_final

    def _load_documents_from_vector_db(self, query):
        """Fetch relevant documents from the vector database.

        Walks every sub-directory under ``vector_db_path``, loads each FAISS
        index found, runs a relevance-scored similarity search, and keeps
        documents scoring above 0.7.

        Args:
            query: Natural-language query to match against the stores.

        Returns:
            List of document page contents sorted by descending relevance.
        """
        results = []
        print(f"Processing query: {query}")

        for root, dirs, files in os.walk(self.vector_db_path):
            print(f"Searching in directory: {root}")
            for dir in dirs:
                index_path = os.path.join(root, dir, "index.faiss")
                # Only directories that actually contain a FAISS index matter.
                if not os.path.exists(index_path):
                    continue
                print(f"Found FAISS index at: {index_path}")

                # Load the FAISS vector store; skip stores that fail to load.
                try:
                    vector_store = FAISS.load_local(
                        os.path.join(root, dir),
                        self.openai_embeddings,
                        allow_dangerous_deserialization=True
                    )
                    print(f"Loaded FAISS vector store from: {os.path.join(root, dir)}")
                except Exception as e:
                    print(f"Error loading FAISS store: {e}")
                    continue

                try:
                    response_with_scores = vector_store.similarity_search_with_relevance_scores(query, k=100)
                    print(f"Similarity search returned {len(response_with_scores)} results.")

                    # Keep only confidently relevant documents (score > 0.7).
                    filtered_results = [
                        (doc, score) for doc, score in response_with_scores
                        if score is not None and score > 0.7
                    ]
                    print(f"Filtered results: {filtered_results}")
                    for doc, score in filtered_results:
                        print(f"Document: {doc.page_content[:100]}... Score: {score}")

                    results.extend([(doc.page_content, score) for doc, score in filtered_results])
                except Exception as e:
                    print(f"Error during similarity search: {e}")

        # Highest-scoring documents first.
        sorted_results = [doc for doc, score in sorted(results, key=lambda x: -x[1])]
        print(f"Total results after sorting: {len(sorted_results)}")
        return sorted_results

    def _load_schema_from_database(self, query):
        """Fetch the database schema, generate a SQL query from the user's
        question via the LLM, execute it, and return the result.

        Args:
            query: Natural-language question to translate into SQL.

        Returns:
            The raw result of running the generated SQL, or a human-readable
            error/empty-result message string.
        """
        try:
            # Fetch the schema description for prompt grounding.
            schema = self.sql_db.get_table_info()

            # Prompt template: schema + question -> bare SQL.
            template_query_generation = """
            Based on the table schema below, write a SQL query that would answer the user's question.
            Only write the SQL query without explanations.

            Schema:
            {schema}

            Question: {question}

            SQL Query:
            """
            prompt = PromptTemplate(
                input_variables=["schema", "question"],
                template=template_query_generation
            )

            # Dedicated LLM instance for SQL generation.
            llm = ChatOpenAI(model_name="gpt-4", api_key=self.openai_api_key, max_tokens=500, temperature=0.2)

            # Runnable pipeline: prompt -> model -> plain string.
            chain = prompt | llm | StrOutputParser()

            sql_query = chain.invoke({"schema": schema, "question": query}).strip()

            if not sql_query:
                return "Could not generate an SQL query for your question."

            # Execute the generated SQL against the configured database.
            try:
                result = self.sql_db.run(sql_query)
                print(f"SQL query executed successfully. Result: {result}")
            except Exception as e:
                return f"Error executing SQL query: {str(e)}"

            if not result:
                return "Query executed, but no results were returned."

            return result

        except Exception as e:
            return f"Error fetching schema details or processing query: {str(e)}"

    def answer_question(self, query, visual_query):
        """Route *query* through a zero-shot agent over the SQL-schema tool.

        Args:
            query: Natural-language question to answer.
            visual_query: Optional charting request. When not None and the
                agent's answer contains a markdown table, the table is parsed
                into the shared ``global_df`` and a chart is generated.

        Returns:
            Tuple of (text response, chart response or None). On failure the
            first element is an error string and the second is None.
        """
        global global_df
        tools = [
            {
                "name": "Database Schema",
                "function": lambda q: self._load_schema_from_database(q),
                "description": """Search within the database schema and generate SQL-based responses.
                The database has single table 'sarima_forecast_results' which contains the columns 'material_date', 'material_name', 'material_count', and 'type'. If the material name is given, frame the query in such a way that the material_name is not case-sensitive.
                display the response as a combination of response summary and the response data in the form of table.
                If the user requested comparison between two or more years or the user asks for the data for all years, data should be shown for all the years with month as first column and the years like 2020, 2021 etc as the adjacent columns. Do not show everything in the same column. (For example, if the user requested from 2020 to 2024, then display the output table with the columns [Month, Material value in 2020, Material value in 2020, Material value in 2021, Material value in 2022, Material value in 2023, Material value in 2024]) so that the records will be displayed for all the months from Jaunary to December across the years.
                include the table data in the Final answer.""",
            },
        ]

        agent_prompt = PromptTemplate(
            input_variables=["input", "agent_scratchpad"],
            template="""
            You are a highly skilled AI assistant specializing in document analysis.
            I have uploaded a document containing material demand forecasts with columns for 'date', 'Material Name', 'Material Quantity', and 'Type'.

            The data includes historical demand information for various items.

            1. The uploaded document includes:
            - **Date:** The date of demand entry.
            - **Material Name:** The name of the material or equipment.
            - **Material Quantity:** The number of units Utilized or forecasted.
            - **Type:** Type contains actual or forecasted, actual represents the actual material utilized and forecasted represents the prediction by ai model.

            2. I may ask questions such as:
            - Forecasting future demand for specific items.
            - Analyzing trends or patterns for materials over time.
            - Summarizing the highest or lowest demands within a specific date range.
            - Comparing demand values between two or more items.

            Your task:
            - If the query relates to forecasting or involves the uploaded document, extract the necessary information from it
            and provide precise, professional, and data-driven responses.

            Make sure your answers are aligned with the uploaded document, depending on the context of the query.
            display the response as mentioned in the tool description. display the output table whereever it is required.
            include the table data in the Final answer if it is there.
            Tools available to you:
            {tools}

            Input Question:
            {input}

            {agent_scratchpad}
            """,
        )

        # NOTE(review): the zero-shot agent builds its own ReAct prompt; the
        # `prompt=` kwarg here may be ignored by initialize_agent — confirm
        # against the pinned langchain version.
        agent = initialize_agent(
            tools=[Tool(name=t["name"], func=t["function"], description=t["description"]) for t in tools],
            llm=self.llm_openai,
            agent="zero-shot-react-description",
            verbose=True,
            prompt=agent_prompt
        )

        try:
            response = agent.invoke(query, handle_parsing_errors=True)
            print(f"response:{response}")

            # AgentExecutor returns a dict with an "output" key; unwrap it.
            if isinstance(response, dict) and "output" in response:
                response = response["output"]

            # Fix: previously `visual_response` was never assigned when
            # visual_query was None, raising a NameError that the broad
            # except below turned into a bogus error response.
            visual_response = None
            if visual_query is not None:
                # A markdown table is detected by pipe characters plus a
                # header-separator row.
                if "|" in response and "---" in response:
                    print("Table data is present in the response.")
                    # Extract table rows (uses module-level `re`).
                    table_data = re.findall(r"\|.*\|", response)
                    # Remove separator lines (like |---|---|).
                    filtered_data = [row for row in table_data if not re.match(r"\|\-+\|", row)]

                    # Split rows into columns.
                    split_data = [row.strip('|').split('|') for row in filtered_data]

                    # First row is the header; the rest are data rows.
                    columns = [col.strip() for col in split_data[0]]
                    data = [list(map(str.strip, row)) for row in split_data[1:]]
                    global_df = pd.DataFrame(data, columns=columns)
                    # Coerce digit/date string columns to proper dtypes.
                    global_df = convert_column_types(global_df)
                    print(f"Dataframe created from response:\n{global_df}")
                    visual_response = create_visualization_csv(visual_query)
                else:
                    print("No table data found in the response.")
            return response, visual_response
        except Exception as e:
            return f"Error while processing your query: {str(e)}", None
249
+
250
+
251
+
252
def create_visualization_csv(visual_query):
    """Render a chart for *visual_query* from the shared ``global_df``.

    Uses pandasai's SmartDataframe over the module-level ``global_df``
    (populated by ChatHandler.answer_question when a table is parsed).

    Args:
        visual_query: User's charting request in natural language.

    Returns:
        The chart result produced by pandasai, or an error message string
        when generation reports no result.
    """
    # Augment the user request with explicit charting instructions so the
    # generated plot has proper axes, labels and a legend.
    visual_query = visual_query + " create chart with suitable x and y axis as user requested. use proper axis values. Do not miss any values. add legend in the chart. mention axis labels in the chart. mention only month name in date axis and not the numbers."

    llm_chart = OpenAI()
    sdf = SmartDataframe(global_df, config={"llm": llm_chart})
    llm_response = sdf.chat(visual_query)
    # sdf.chat may return non-string results (e.g. a chart path or figure);
    # only probe for the failure marker when we actually got text, otherwise
    # the `in` test would raise TypeError.
    if isinstance(llm_response, str) and "no result" in llm_response:
        return " There is a problem in generating the chart. Please try again after some time."
    return llm_response
271
+
272
+
273
def convert_column_types(df):
    """Best-effort conversion of string columns to int or datetime, in place.

    Each object-dtype column becomes ``int`` when every cell is a pure digit
    string, otherwise ``datetime64`` when every cell parses as ``YYYY-MM-DD``;
    columns matching neither are left as strings. Already-typed (non-object)
    columns are skipped — previously calling ``.str`` on them raised
    AttributeError.

    Args:
        df: DataFrame whose cells are expected to be strings (as produced by
            parsing a markdown table).

    Returns:
        The same DataFrame with converted dtypes.
    """
    for col in df.columns:
        series = df[col]
        # Skip columns that already carry a non-string dtype.
        if not pd.api.types.is_object_dtype(series):
            continue
        # Integer conversion: every cell must be a digit-only string.
        # (Series.all() instead of builtin all() so NaN is handled via skipna.)
        if series.str.isdigit().all():
            df[col] = series.astype(int)
        else:
            try:
                # Strict ISO-date parse; errors='raise' rejects mixed columns.
                df[col] = pd.to_datetime(series, format='%Y-%m-%d', errors='raise')
            except (ValueError, TypeError):
                # TypeError too: non-string cells previously escaped the
                # original `except ValueError` and crashed the caller.
                pass
    return df