shivi commited on
Commit
054d60a
1 Parent(s): 684811b

minor changes

Browse files
Files changed (1) hide show
  1. utils/predict.py +4 -4
utils/predict.py CHANGED
@@ -19,9 +19,9 @@ def batch_predict(input_data):
19
 
20
  input_data.to_csv(input_data_file, index=None, header=None)
21
 
22
- prod_dataset = get_dataset_from_csv(input_data_file, shuffle=True)
23
 
24
- pred = model.predict(prod_dataset)
25
 
26
  for prediction, actual_gt in zip(pred, input_data['income_level'].values.tolist()):
27
  y_pred_prob = round(prediction.flatten()[0] * 100, 2)
@@ -96,10 +96,10 @@ def user_input_predict(age, wage, cap_gains, cap_losses, dividends, num_persons,
96
  input_data_file = "input_data.csv"
97
 
98
  input_df.to_csv(input_data_file, index=None, header=None)
99
- prod_dataset = get_dataset_from_csv(input_data_file, shuffle=True)
100
 
101
  labels = ['Income greater than 50000',"Income less than 50000"]
102
- prediction = model.predict(prod_dataset)
103
  y_pred_prob = round(prediction[0].flatten()[0],5)
104
  y_not_prob = round(1-prediction[0].flatten()[0],3)
105
 
 
19
 
20
  input_data.to_csv(input_data_file, index=None, header=None)
21
 
22
+ input_dataset = get_dataset_from_csv(input_data_file, shuffle=True)
23
 
24
+ pred = model.predict(input_dataset)
25
 
26
  for prediction, actual_gt in zip(pred, input_data['income_level'].values.tolist()):
27
  y_pred_prob = round(prediction.flatten()[0] * 100, 2)
 
96
  input_data_file = "input_data.csv"
97
 
98
  input_df.to_csv(input_data_file, index=None, header=None)
99
+ input_dataset = get_dataset_from_csv(input_data_file, shuffle=True)
100
 
101
  labels = ['Income greater than 50000',"Income less than 50000"]
102
+ prediction = model.predict(input_dataset)
103
  y_pred_prob = round(prediction[0].flatten()[0],5)
104
  y_not_prob = round(1-prediction[0].flatten()[0],3)
105