eaglelandsonce committed
Commit a2e802a · verified · 1 parent: f93ed07

Update app.py

Files changed (1): app.py (+36, -129)
app.py CHANGED
@@ -1,21 +1,5 @@
-import streamlit as st
-import pandas as pd
-import torch
-import torch.nn as nn
-import torch.optim as optim
-import matplotlib.pyplot as plt
-from sklearn.preprocessing import StandardScaler, LabelEncoder
-import numpy as np
-
-# Global scaler and label encoder for consistent preprocessing
-scaler = StandardScaler()
-label_encoder = LabelEncoder()
-feature_columns = None # To store feature columns from the training data
-model = None # Declare the model globally for predictions
-
-# Preload default files
-DEFAULT_TRAIN_FILE = "patientdata.csv"
-DEFAULT_PREDICT_FILE = "synthetic_breast_cancer_data_withColumn.csv"
+from sklearn.metrics import classification_report, confusion_matrix
+import seaborn as sns # For confusion matrix heatmap
 
 def main():
     global feature_columns, model
@@ -69,8 +53,8 @@ def main():
         st.error(f"Error during model training: {e}")
         return
 
-    # Upload data for prediction
-    st.write("Upload new data for prediction (ensure 'Treatment' column is removed if present).")
+    # Upload data for prediction and comparison
+    st.write("Upload new data for prediction and evaluation.")
     new_data_file = st.file_uploader("Upload new CSV file for prediction", type="csv")
     if new_data_file is None:
         st.write("Using default prediction data.")
@@ -86,14 +70,17 @@
             st.error(f"Error loading uploaded prediction file: {e}")
             return
 
-    # Drop 'Treatment' column if it exists
-    if 'Treatment' in new_data.columns:
-        st.warning("The 'Treatment' column is present in the prediction data and will be removed.")
-        new_data = new_data.drop(columns=['Treatment'])
-
     st.write("Prediction Dataset Preview:")
     st.dataframe(new_data.head()) # Display new data
 
+    if 'Treatment' not in new_data.columns:
+        st.error("The prediction file must contain a 'Treatment' column for evaluation.")
+        return
+
+    # Extract true labels and drop Treatment for prediction
+    true_labels = label_encoder.transform(new_data['Treatment'])
+    new_data = new_data.drop(columns=['Treatment'])
+
     if model is not None and feature_columns is not None:
         try:
             # Align columns to match training data
@@ -101,115 +88,35 @@ def main():
 
             if new_data_aligned is not None:
                 predictions = predict_treatment(new_data_aligned, model)
-
-                # Display Predictions in an Output Box
-                st.subheader("Predicted Treatment Outcomes")
-                prediction_output = "\n".join([f"Patient {i+1}: {pred}" for i, pred in enumerate(predictions)])
-                st.text_area("Prediction Results", prediction_output, height=200)
+
+                # Evaluation Metrics
+                st.subheader("Model Evaluation Metrics")
+                classification_metrics(true_labels, predictions)
+
+                # Visualize Confusion Matrix
+                confusion_mat = confusion_matrix(true_labels, predictions)
+                plot_confusion_matrix(confusion_mat, label_encoder.classes_)
             else:
                 st.error("Unable to align prediction data to the training feature columns.")
         except Exception as e:
-            st.error(f"Error during prediction: {e}")
+            st.error(f"Error during prediction or evaluation: {e}")
     else:
-        st.warning("Please train the model first before predicting on new data.")
-
-def preprocess_training_data(data):
-    global scaler, label_encoder
-
-    # Label encode the 'Treatment' target column
-    data['Treatment'] = label_encoder.fit_transform(data['Treatment'])
-    y = data['Treatment'].values
-
-    # Encode and standardize feature columns
-    X = data.drop('Treatment', axis=1)
-    feature_columns = X.columns # Store feature columns for later alignment
-    for col in X.select_dtypes(include=['object']).columns:
-        X[col] = LabelEncoder().fit_transform(X[col])
-
-    # Standardize features
-    X = scaler.fit_transform(X)
-
-    return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.long), X.shape[1], len(np.unique(y)), feature_columns
-
-def align_columns(new_data, feature_columns):
-    try:
-        # Ensure the new data has the same columns as the training data
-        missing_cols = set(feature_columns) - set(new_data.columns)
-        extra_cols = set(new_data.columns) - set(feature_columns)
-
-        # Remove any extra columns
-        new_data = new_data.drop(columns=extra_cols)
-
-        # Add missing columns with default value 0
-        for col in missing_cols:
-            new_data[col] = 0
-
-        # Reorder columns to match the training data
-        new_data = new_data[feature_columns]
-
-        # Encode and standardize feature columns
-        for col in new_data.select_dtypes(include=['object']).columns:
-            new_data[col] = LabelEncoder().fit_transform(new_data[col])
-
-        # Scale features
-        new_data = scaler.transform(new_data)
-
-        return torch.tensor(new_data, dtype=torch.float32)
-    except Exception as e:
-        st.error(f"Error aligning columns: {e}")
-        return None
-
-def train_model(X, y, input_dim, hidden_dim, num_classes, learning_rate, epochs):
-    class SimpleNN(nn.Module):
-        def __init__(self, input_dim, hidden_dim, num_classes):
-            super(SimpleNN, self).__init__()
-            self.fc1 = nn.Linear(input_dim, hidden_dim)
-            self.relu = nn.ReLU()
-            self.fc2 = nn.Linear(hidden_dim, num_classes)
-
-        def forward(self, x):
-            x = self.fc1(x)
-            x = self.relu(x)
-            x = self.fc2(x)
-            return x
-
-    model = SimpleNN(input_dim, hidden_dim, num_classes)
-    criterion = nn.CrossEntropyLoss()
-    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
-
-    loss_curve = []
-    for epoch in range(epochs):
-        optimizer.zero_grad()
-        outputs = model(X)
-        loss = criterion(outputs, y)
-        loss.backward()
-        optimizer.step()
-        loss_curve.append(loss.item())
-
-    return model, loss_curve
-
-def plot_loss_curve(loss_curve):
-    plt.figure()
-    plt.plot(loss_curve, label="Training Loss")
-    plt.xlabel("Epochs")
-    plt.ylabel("Loss")
-    plt.title("Loss Curve")
-    plt.legend()
-    plt.tight_layout() # Ensure layout is tight for Streamlit
+        st.warning("Please train the model first before predicting and evaluating on new data.")
+
+def classification_metrics(true_labels, predictions):
+    # Generate classification report
+    report = classification_report(true_labels, predictions, target_names=label_encoder.classes_, output_dict=True)
+    st.write("Classification Report:")
+    st.table(pd.DataFrame(report).transpose())
+
+def plot_confusion_matrix(confusion_mat, classes):
+    # Plot confusion matrix
+    plt.figure(figsize=(8, 6))
+    sns.heatmap(confusion_mat, annot=True, fmt="d", cmap="Blues", xticklabels=classes, yticklabels=classes)
+    plt.xlabel("Predicted Labels")
+    plt.ylabel("True Labels")
+    plt.title("Confusion Matrix")
     st.pyplot(plt)
 
-def predict_treatment(new_data, model, batch_size=32):
-    model.eval()
-    predictions = []
-
-    with torch.no_grad():
-        for i in range(0, new_data.size(0), batch_size):
-            batch_data = new_data[i:i + batch_size]
-            outputs = model(batch_data)
-            _, batch_predictions = torch.max(outputs, 1)
-            predictions.extend(batch_predictions.numpy())
-
-    return label_encoder.inverse_transform(predictions)
-
 if __name__ == "__main__":
     main()
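
Note that the hunks above delete the imports of streamlit, pandas, torch, and matplotlib as well as the predict_treatment helper, while the added evaluation code still references plt, pd, st, and predict_treatment. A minimal, self-contained sketch of just the added metrics display, with the imports it needs and hypothetical labels standing in for real model output, might look like this:

# Minimal sketch of the evaluation flow added in this commit (not the full app).
# Assumes streamlit, pandas, matplotlib, seaborn, and scikit-learn are installed;
# true_labels and predictions are hypothetical stand-ins for real model output.
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import streamlit as st
from sklearn.metrics import classification_report, confusion_matrix

true_labels = ["Chemo", "Radiation", "Chemo", "Surgery", "Surgery"]  # hypothetical
predictions = ["Chemo", "Radiation", "Surgery", "Surgery", "Chemo"]  # hypothetical
classes = sorted(set(true_labels))

# Classification report rendered as a table, as classification_metrics() does above
report = classification_report(true_labels, predictions, output_dict=True)
st.write("Classification Report:")
st.table(pd.DataFrame(report).transpose())

# Confusion-matrix heatmap, as plot_confusion_matrix() does above; passing an
# explicit figure to st.pyplot avoids relying on matplotlib's global figure state
fig = plt.figure(figsize=(8, 6))
sns.heatmap(confusion_matrix(true_labels, predictions, labels=classes),
            annot=True, fmt="d", cmap="Blues",
            xticklabels=classes, yticklabels=classes)
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.title("Confusion Matrix")
st.pyplot(fig)

Saved as, say, eval_sketch.py, this runs standalone with "streamlit run eval_sketch.py".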
 
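
One behavior worth knowing about the new true-label extraction: scikit-learn's LabelEncoder.transform raises a ValueError when given a label that was absent during fitting, so an uploaded file with an unexpected Treatment value will fall through to the generic error handler. A sketch of a guard (reusing the names from the diff; it would sit just before the transform call in main) that surfaces the problem explicitly:

# Sketch only: validate uploaded labels before calling label_encoder.transform,
# which raises ValueError on labels unseen during fit.
unseen = set(new_data['Treatment']) - set(label_encoder.classes_)
if unseen:
    st.error(f"Prediction file contains Treatment labels not seen in training: {sorted(unseen)}")
    return
true_labels = label_encoder.transform(new_data['Treatment'])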