robertselvam commited on
Commit
0e3fc88
·
verified ·
1 Parent(s): a5cd68d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +423 -0
app.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from io import StringIO
3
+ import pandas as pd
4
+ import numpy as np
5
+ import xgboost as xgb
6
+ from math import sqrt
7
+ from sklearn.metrics import mean_squared_error
8
+ from sklearn.model_selection import train_test_split
9
+ import plotly.express as px
10
+ import logging
11
+
12
+ from datetime import datetime
13
+
14
+ import plotly.graph_objects as go
15
+ import numpy as np
16
+ import matplotlib.pyplot as plt
17
+ from matplotlib import pyplot
18
+ import whisper
19
+ from openai import AzureOpenAI
20
+ import json
21
+ import re
22
+ import gradio as gr
23
+
24
+ # Configure logging
25
+ logging.basicConfig(
26
+ filename='demand_forecasting.log', # You can adjust the log file name here
27
+ filemode='a',
28
+ format='[%(asctime)s] [%(levelname)s] [%(filename)s] [%(lineno)s:%(funcName)s()] %(message)s',
29
+ datefmt='%Y-%b-%d %H:%M:%S'
30
+ )
31
+ LOGGER = logging.getLogger(__name__)
32
+
33
+ log_level_env = 'INFO' # You can adjust the log level here
34
+ log_level_dict = {
35
+ 'DEBUG': logging.DEBUG,
36
+ 'INFO': logging.INFO,
37
+ 'WARNING': logging.WARNING,
38
+ 'ERROR': logging.ERROR,
39
+ 'CRITICAL': logging.CRITICAL
40
+ }
41
+ if log_level_env in log_level_dict:
42
+ log_level = log_level_dict[log_level_env]
43
+ else:
44
+ log_level = log_level_dict['INFO']
45
+ LOGGER.setLevel(log_level)
46
+
47
+ class DemandForecasting:
48
+ def __init__(self):
49
+ self.client = AzureOpenAI()
50
+ self.whisper_model = whisper.load_model("medium.en")
51
+
52
+
53
+ def get_column(self,train_csv_path: str):
54
+ # Load the training data from the specified CSV file
55
+ train_df = pd.read_csv(train_csv_path)
56
+
57
+ column_names = train_df.columns.tolist()
58
+ return column_names
59
+
60
+ def load_data(self, train_csv_path: str) -> pd.DataFrame:
61
+ """
62
+ Load training data from a CSV file.
63
+
64
+ Args:
65
+ train_csv_path (str): Path to the training CSV file.
66
+
67
+ Returns:
68
+ pd.DataFrame: DataFrame containing the training data.
69
+ """
70
+ try:
71
+ # Load the training data from the specified CSV file
72
+ train_df = pd.read_csv(train_csv_path)
73
+
74
+
75
+ # Return a tuple containing the training DataFrame
76
+ return train_df
77
+
78
+ except Exception as e:
79
+ # Log an error message if an exception occurs during data loading
80
+ LOGGER.error(f"Error loading data: {e}")
81
+
82
+ # Return None
83
+ return None
84
+
85
+
86
+ def find_date_column(self, df_data: pd.DataFrame, list_columns: list) -> str:
87
+ """
88
+ Find the column containing date information from the list of columns.
89
+
90
+ Args:
91
+ - df_data (pd.DataFrame): Input DataFrame.
92
+ - list_columns (list): List of column names to search for date information.
93
+
94
+ Returns:
95
+ - str: Name of the column containing date information.
96
+ """
97
+ for column in list_columns:
98
+ # Check if the column contains date-like values
99
+ try:
100
+ pd.to_datetime(df_data[column])
101
+ return column
102
+ except ValueError:
103
+ pass
104
+
105
+ # Return None if no date column is found
106
+ return None
107
+
108
+ def preprocess_data(self, df_data: pd.DataFrame, list_columns) -> pd.DataFrame:
109
+ """
110
+ Preprocess the input DataFrame.
111
+
112
+ Args:
113
+ - df_data (pd.DataFrame): Input DataFrame to preprocess.
114
+
115
+ Returns:
116
+ - pd.DataFrame: Preprocessed DataFrame.
117
+ """
118
+ try:
119
+ print(type(list_columns))
120
+ # Make a copy of the input DataFrame to avoid modifying the original data
121
+ df_data = df_data.copy()
122
+
123
+ list_columns.append(target_column)
124
+
125
+ # Drop columns not in list_columns
126
+ columns_to_drop = [col for col in df_data.columns if col not in list_columns]
127
+ df_data.drop(columns=columns_to_drop, inplace=True)
128
+
129
+ # Find the date column
130
+ date_column = self.find_date_column(df_data, list_columns)
131
+ if date_column is None:
132
+ raise ValueError("No date column found in the provided list of columns.")
133
+
134
+
135
+
136
+ # Parse date information
137
+ df_data[date_column] = pd.to_datetime(df_data[date_column]) # Convert 'date' column to datetime format
138
+ df_data['day'] = df_data[date_column].dt.day # Extract day of the month
139
+ df_data['month'] = df_data[date_column].dt.month # Extract month
140
+ df_data['year'] = df_data[date_column].dt.year # Extract year
141
+
142
+ # Cyclical Encoding for Months
143
+ df_data['month_sin'] = np.sin(2 * np.pi * df_data['month'] / 12) # Cyclical sine encoding for month
144
+ df_data['month_cos'] = np.cos(2 * np.pi * df_data['month'] / 12) # Cyclical cosine encoding for month
145
+
146
+ # Day of the Week
147
+ df_data['day_of_week'] = df_data[date_column].dt.weekday # Extract day of the week (0 = Monday, 6 = Sunday)
148
+
149
+ # Week of the Year
150
+ df_data['week_of_year'] = df_data[date_column].dt.isocalendar().week.astype(int) # Extract week of the year as integer
151
+
152
+ df_data.drop(columns=[date_column], inplace=True)
153
+
154
+ print("df_data", df_data)
155
+ return df_data
156
+
157
+ except Exception as e:
158
+ # Log an error message if an exception occurs during data preprocessing
159
+ LOGGER.error(f"Error preprocessing data: {e}")
160
+
161
+ # Return None in case of an error
162
+ return None
163
+
164
+ def train_model(self, train: pd.DataFrame, target_column, list_columns) -> tuple:
165
+ """
166
+ Train an XGBoost model using the provided training data.
167
+
168
+ Args:
169
+ - train (pd.DataFrame): DataFrame containing training data.
170
+
171
+ Returns:
172
+ - tuple: A tuple containing the trained model, true validation labels, and predicted validation labels.
173
+ """
174
+ try:
175
+
176
+ # Extract features and target variable
177
+ X = train.drop(columns=[target_column])
178
+ y = train[target_column]
179
+
180
+ # Cannot use cross validation because it will use future data
181
+ X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=333)
182
+
183
+ # Convert data into DMatrix format for XGBoost
184
+ dtrain = xgb.DMatrix(X_train, label=y_train)
185
+ dval = xgb.DMatrix(X_val, label=y_val)
186
+
187
+ # Parameters for XGBoost
188
+ param = {
189
+ 'max_depth': 9,
190
+ 'eta': 0.3,
191
+ 'objective': 'reg:squarederror'
192
+ }
193
+
194
+ num_round = 60
195
+
196
+ # Train the model
197
+ model_xgb = xgb.train(param, dtrain, num_round)
198
+
199
+ # Validate the model
200
+ y_val_pred = model_xgb.predict(dval) # Predict validation set labels
201
+
202
+ # Calculate mean squared error
203
+ mse = mean_squared_error(y_val, y_val_pred)
204
+
205
+ # Print validation RMSE
206
+ validation = f"Validation RMSE: {np.sqrt(mse)}"
207
+
208
+ # Return trained model, true validation labels, and predicted validation labels
209
+ return model_xgb, y_val, y_val_pred, validation
210
+
211
+ except Exception as e:
212
+ # Log an error message if an exception occurs during model training
213
+ LOGGER.error(f"Error training model: {e}")
214
+
215
+ # Return None for all outputs in case of an error
216
+ return None, None, None
217
+
218
+ def plot_evaluation_interactive(self, y_true: np.ndarray, y_pred: np.ndarray, title: str) -> None:
219
+ """
220
+ Plot interactive evaluation using Plotly.
221
+
222
+ Args:
223
+ - y_true (np.ndarray): True values.
224
+ - y_pred (np.ndarray): Predicted values.
225
+ - title (str): Title of the plot.
226
+ """
227
+ try:
228
+ # Create a scatter plot using Plotly
229
+ fig = px.scatter(x=y_true, y=y_pred, labels={'x': 'True Values', 'y': 'Predictions'}, title=title, color_discrete_map={'': 'purple'})
230
+ fig.show()
231
+ return fig
232
+
233
+ except Exception as e:
234
+ # Log an error message if an exception occurs during plot generation
235
+ LOGGER.error(f"Error plotting evaluation: {e}")
236
+
237
+
238
+ def predict_sales_for_date(self, input_data, model: xgb.Booster) -> float:
239
+ """
240
+ Predict the sales for a specific date using the trained model.
241
+
242
+ Args:
243
+ - date_input (str): Date for which sales prediction is needed (in 'YYYY-MM-DD' format).
244
+ - model (xgb.Booster): Trained XGBoost model.
245
+ - features (pd.DataFrame): DataFrame containing features for the date.
246
+
247
+ Returns:
248
+ - float: Predicted sales value.
249
+ """
250
+ try:
251
+ input_features = pd.DataFrame([input_data])
252
+
253
+ # Regular expression pattern for date in the format 'dd-mm-yyyy'
254
+ for key, value in input_data.items():
255
+ if isinstance(value, str) and re.match(r'\d{2}-\d{2}-\d{4}', value):
256
+ date_column = key
257
+
258
+ if date_column:
259
+ # # Assuming date_input is a datetime object
260
+ date_input = pd.to_datetime(input_features[date_column])
261
+
262
+ # Extract day of the month
263
+ input_features['day'] = date_input.dt.day
264
+
265
+ # Extract month
266
+ input_features['month'] = date_input.dt.month
267
+
268
+ # Extract year
269
+ input_features['year'] = date_input.dt.year
270
+
271
+ # Cyclical sine encoding for month
272
+ input_features['month_sin'] = np.sin(2 * np.pi * input_features['month'] / 12)
273
+
274
+ # Cyclical cosine encoding for month
275
+ input_features['month_cos'] = np.cos(2 * np.pi * input_features['month'] / 12)
276
+
277
+ # Extract day of the week (0 = Monday, 6 = Sunday)
278
+ input_features['day_of_week'] = date_input.dt.weekday
279
+
280
+ # Extract week of the year as integer
281
+ input_features['week_of_year'] = date_input.dt.isocalendar().week
282
+
283
+
284
+ input_features.drop(columns=[date_column], inplace=True)
285
+
286
+ # Convert input features to DMatrix format
287
+ dinput = xgb.DMatrix(input_features)
288
+
289
+ # Make predictions using the trained model
290
+ predicted_sales = model.predict(dinput)[0]
291
+
292
+ # Print the predicted sales value
293
+ predicted_result = f"""{input_data[str(date_column)]}Predicted Value Is {predicted_sales}"""
294
+ # Return the predicted sales value
295
+ return predicted_result
296
+
297
+ except Exception as e:
298
+ # Log an error message if an exception occurs during sales prediction
299
+ LOGGER.error(f"Error predicting sales: {e}")
300
+
301
+ # Return None in case of an error
302
+ return None
303
+
304
+ def audio_to_text(self, audio_path):
305
+ """
306
+ transcribe the audio to text.
307
+ """
308
+
309
+
310
+ result = self.whisper_model.transcribe(audio_path)
311
+ print("audio_to_text",result["text"])
312
+ return result["text"]
313
+
314
+
315
+ def parse_text(self, text, column_list):
316
+
317
+ # Define the prompt or input for the model
318
+ conversation =[{"role": "system", "content": ""},
319
+ {"role": "user", "content":f""" extract the {column_list}. al
320
+ l values should be intiger data type. if date in there the format is dd-mm-YYYY.
321
+ text```{text}```
322
+ return result should be in JSON format:
323
+
324
+ """
325
+ }]
326
+
327
+ # Generate a response from the GPT-3 model
328
+ chat_completion = self.client.chat.completions.create(
329
+ model = "GPT-3",
330
+ messages = conversation,
331
+ max_tokens=500,
332
+ temperature=0,
333
+ n=1,
334
+ stop=None,
335
+ )
336
+
337
+ # Extract the generated text from the API response
338
+ generated_text = chat_completion.choices[0].message.content
339
+
340
+ # Assuming jsonString is your JSON string
341
+ json_data = json.loads(generated_text)
342
+ print("parse_text",json_data)
343
+ return json_data
344
+
345
+ def main(self, train_csv_path: str, audio_path, target_column, column_list) -> None:
346
+ """
347
+ Main function to execute the demand forecasting pipeline.
348
+
349
+ Args:
350
+ - train_csv_path (str): Path to the training CSV file.
351
+ - date (str): Date for which sales prediction is needed (in 'YYYY-MM-DD' format).
352
+ """
353
+ try:
354
+
355
+
356
+ # Split the string by comma and convert it into a list
357
+ column_list = column_list.split(", ")
358
+
359
+ print("train_csv_path", train_csv_path)
360
+ print("audio_path", audio_path)
361
+ print("column_list", column_list)
362
+ print("target_column", target_column)
363
+
364
+ text = self.audio_to_text(audio_path)
365
+
366
+ input_data = self.parse_text(text, column_list)
367
+
368
+ #load data
369
+ train_data = self.load_data(train_csv_path)
370
+
371
+ #preprocess the train data
372
+ train_df = self.preprocess_data(train_data, column_list)
373
+
374
+ # Train model and get validation predictions
375
+ trained_model, y_val, y_val_pred, validation = self.train_model(train_df, target_column, column_list)
376
+
377
+ # Plot interactive evaluation for training
378
+ plot = self.plot_evaluation_interactive(y_val, y_val_pred, title='Validation Set Evaluation')
379
+
380
+ # Predict sales for the specified date using the trained model
381
+ predicted_value = self.predict_sales_for_date(input_data, trained_model)
382
+
383
+ return plot, predicted_value, validation
384
+
385
+ except Exception as e:
386
+ # Log an error message if an exception occurs in the main function
387
+ LOGGER.error(f"Error in main function: {e}")
388
+
389
+ def gradio_interface(self):
390
+ with gr.Blocks(css="style.css", theme="freddyaboulton/test-blue") as demo:
391
+
392
+ gr.HTML("""<center><h1 style="color:#fff">Demand Forecasting</h1></center>""")
393
+
394
+ with gr.Row():
395
+ with gr.Column(scale=0.50):
396
+ train_csv = gr.File(elem_classes="uploadbutton")
397
+ with gr.Column(scale=0.50):
398
+ column_list = gr.Textbox(label="Column List")
399
+
400
+ with gr.Row():
401
+ with gr.Column(scale=0.50):
402
+ audio_path = gr.Audio(sources=["microphone"], type="filepath")
403
+ with gr.Row():
404
+ with gr.Column(scale=0.50):
405
+ selected_column = gr.Textbox(label="Select column")
406
+ with gr.Column(scale=0.50):
407
+ target_column = gr.Textbox(label="target column")
408
+
409
+
410
+ with gr.Row():
411
+ validation = gr.Textbox(label="Validation")
412
+ predicted_result = gr.Textbox(label="Predicted Result")
413
+ plot = gr.Plot()
414
+
415
+ train_csv.upload(self.get_column, train_csv, column_list)
416
+ audio_path.stop_recording(self.main, [train_csv, audio_path, target_column, selected_column], [plot, predicted_result, validation])
417
+
418
+ demo.launch(debug=True)
419
+
420
+ if __name__ == "__main__":
421
+
422
+ demand = DemandForecasting()
423
+ demand.gradio_interface()