eaglelandsonce committed
Commit 39cc7e2 · verified · 1 Parent(s): dd06c5f

Update app.py

Files changed (1)
  1. app.py +43 -44
app.py CHANGED
@@ -12,7 +12,10 @@ scaler = StandardScaler()
 label_encoder = LabelEncoder()
 feature_columns = None # To store feature columns from the training data
 
-# Streamlit App
+# Preload default files
+DEFAULT_TRAIN_FILE = "patientdata.csv"
+DEFAULT_PREDICT_FILE = "synthetic_breast_cancer_notreatmentcolumn.csv"
+
 def main():
     global feature_columns
 
@@ -21,49 +24,56 @@ def main():
 
     # Upload training data
     uploaded_file = st.file_uploader("Upload a CSV file for training", type="csv")
-    if uploaded_file is not None:
+    if uploaded_file is None:
+        st.write("Using default training data.")
+        data = pd.read_csv(DEFAULT_TRAIN_FILE)
+    else:
         data = pd.read_csv(uploaded_file)
-        st.write("Training Dataset Preview:", data.head())
+    st.write("Training Dataset Preview:", data.head())
 
-        # Check for Treatment column in training data
-        if 'Treatment' not in data.columns:
-            st.error("The training data must contain a 'Treatment' column.")
-            return
+    # Check for Treatment column in training data
+    if 'Treatment' not in data.columns:
+        st.error("The training data must contain a 'Treatment' column.")
+        return
 
-        # Prepare Data
-        X, y, input_dim, num_classes, feature_columns = preprocess_training_data(data)
+    # Prepare Data
+    X, y, input_dim, num_classes, feature_columns = preprocess_training_data(data)
 
-        # Model Parameters
-        hidden_dim = st.slider("Hidden Layer Dimension", 10, 100, 50)
-        learning_rate = st.number_input("Learning Rate", 0.0001, 0.1, 0.001)
-        epochs = st.number_input("Epochs", 1, 100, 20)
+    # Model Parameters
+    hidden_dim = st.slider("Hidden Layer Dimension", 10, 100, 50)
+    learning_rate = st.number_input("Learning Rate", 0.0001, 0.1, 0.001)
+    epochs = st.number_input("Epochs", 1, 100, 20)
 
-        # Model training
-        if st.button("Train Model"):
-            model, loss_curve = train_model(X, y, input_dim, hidden_dim, num_classes, learning_rate, epochs)
-            plot_loss_curve(loss_curve)
+    # Model training
+    if st.button("Train Model"):
+        model, loss_curve = train_model(X, y, input_dim, hidden_dim, num_classes, learning_rate, epochs)
+        plot_loss_curve(loss_curve)
 
     # Upload data for prediction
     st.write("Upload new data without the 'Treatment' column for prediction.")
     new_data_file = st.file_uploader("Upload new CSV file for prediction", type="csv")
-    if new_data_file is not None:
-        if 'model' in locals() and feature_columns is not None:
-            new_data = pd.read_csv(new_data_file)
-
-            # Align columns to match training data
-            new_data_aligned = align_columns(new_data, feature_columns)
+    if new_data_file is None:
+        st.write("Using default prediction data.")
+        new_data = pd.read_csv(DEFAULT_PREDICT_FILE)
+    else:
+        new_data = pd.read_csv(new_data_file)
+    st.write("Prediction Dataset Preview:", new_data.head())
+
+    if 'model' in locals() and feature_columns is not None:
+        # Align columns to match training data
+        new_data_aligned = align_columns(new_data, feature_columns)
+
+        if new_data_aligned is not None:
+            predictions = predict_treatment(new_data_aligned, model)
 
-            if new_data_aligned is not None:
-                predictions = predict_treatment(new_data_aligned, model)
-
-                # Display Predictions in an Output Box
-                st.subheader("Predicted Treatment Outcomes")
-                prediction_output = "\n".join([f"Patient {i+1}: {pred}" for i, pred in enumerate(predictions)])
-                st.text_area("Prediction Results", prediction_output, height=200)
-            else:
-                st.error("Unable to align prediction data to the training feature columns.")
+            # Display Predictions in an Output Box
+            st.subheader("Predicted Treatment Outcomes")
+            prediction_output = "\n".join([f"Patient {i+1}: {pred}" for i, pred in enumerate(predictions)])
+            st.text_area("Prediction Results", prediction_output, height=200)
         else:
-            st.error("Please train the model first before predicting on new data.")
+            st.error("Unable to align prediction data to the training feature columns.")
+    else:
+        st.warning("Please train the model first before predicting on new data.")
 
 def preprocess_training_data(data):
     global scaler, label_encoder
@@ -149,17 +159,6 @@ def plot_loss_curve(loss_curve):
     st.pyplot(plt)
 
 def predict_treatment(new_data, model, batch_size=32):
-    """
-    Predict treatment outcomes for new data using the trained model.
-
-    Args:
-        new_data (pd.DataFrame): The new dataset without a 'Treatment' column.
-        model (torch.nn.Module): The trained PyTorch model.
-        batch_size (int): Size of data batches for predictions (optional).
-
-    Returns:
-        List of predicted outcomes in the original label format.
-    """
     model.eval()
     predictions = []
 
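A note on the helpers the diff calls but does not show. preprocess_training_data(data) is expected to return X, y, input_dim, num_classes, feature_columns and to use the module-level scaler and label_encoder; its body lies outside the changed hunks. The following is a minimal sketch of an implementation consistent with that call site, an illustration rather than the code in this commit:

import pandas as pd
import torch
from sklearn.preprocessing import StandardScaler, LabelEncoder

# Module-level objects, as at the top of app.py
scaler = StandardScaler()
label_encoder = LabelEncoder()

def preprocess_training_data(data):
    # Hypothetical sketch: split off the 'Treatment' target, one-hot encode any
    # categorical features, scale them, and label-encode the target.
    features = pd.get_dummies(data.drop(columns=["Treatment"]))
    X = torch.tensor(scaler.fit_transform(features), dtype=torch.float32)
    y = torch.tensor(label_encoder.fit_transform(data["Treatment"]), dtype=torch.long)
    input_dim = X.shape[1]
    num_classes = len(label_encoder.classes_)
    feature_columns = features.columns.tolist()
    return X, y, input_dim, num_classes, feature_columns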
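align_columns(new_data, feature_columns) is also defined outside the hunks shown here; the caller treats a None return as an alignment failure. A hypothetical sketch of such a helper, assuming one-hot encoding followed by a reindex against the training columns:

import pandas as pd

def align_columns(new_data, feature_columns):
    # Hypothetical sketch: encode the prediction data the same way as training,
    # add any missing training columns filled with 0, drop unseen columns, and
    # keep the training column order.
    encoded = pd.get_dummies(new_data)
    if not set(feature_columns) & set(encoded.columns):
        return None  # nothing in common with the training features
    return encoded.reindex(columns=feature_columns, fill_value=0)

One behavioural caveat around the call site: Streamlit re-runs the whole script on every widget interaction, so 'model' in locals() only holds during the rerun in which "Train Model" was clicked; storing the trained model in st.session_state is the usual way to keep it available to the prediction branch across reruns.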
 
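The last hunk removes the docstring of predict_treatment but leaves the rest of its body (beyond model.eval() and predictions = []) outside the diff. The removed docstring says the function returns predictions in the original label format, so a batched completion would typically look like the sketch below, assuming the module-level scaler and label_encoder and a classifier that outputs one score per class; again an illustration, not the committed implementation:

import torch

def predict_treatment(new_data, model, batch_size=32):
    model.eval()
    predictions = []
    # Hypothetical continuation: scale with the already-fitted module-level
    # scaler, run the rows through the model in batches, take the arg-max
    # class index per row, and map indices back to the original labels.
    X = torch.tensor(scaler.transform(new_data), dtype=torch.float32)
    with torch.no_grad():
        for start in range(0, len(X), batch_size):
            outputs = model(X[start:start + batch_size])
            predictions.extend(outputs.argmax(dim=1).tolist())
    return list(label_encoder.inverse_transform(predictions))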