xinfyxinfy commited on
Commit
8de2029
·
1 Parent(s): d15e62f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -12
app.py CHANGED
@@ -75,7 +75,7 @@ elif task == "TCR\u03B2-Peptide":
75
 
76
 
77
  ##################### ML predict function
78
- # @st.cache_data
79
  def predict_on_batch_output(dataset,shorttask,group):
80
 
81
  if dataset == 'MCPAS':
@@ -92,6 +92,7 @@ def predict_on_batch_output(dataset,shorttask,group):
92
  model = load_model('models/mcpas/bestmodel_alphabetapeptide.hdf5',compile=False)
93
  #predict_on_batch
94
  output = model.predict_on_batch([alpha_np, beta_np, pep_np])
 
95
  elif dataset=='mcpas' and shorttask=='abpm':
96
  #load data
97
  alpha, beta, pep, mhc = group
@@ -100,6 +101,7 @@ def predict_on_batch_output(dataset,shorttask,group):
100
  model = load_model('models/mcpas/bestmodel_alphabetaptptidemhc.hdf5',compile=False)
101
  #predict_on_batch
102
  output = model.predict_on_batch([alpha_np, beta_np, pep_np, mhc_np])
 
103
  elif dataset=='mcpas' and shorttask=='ap':
104
  #load data
105
  alpha, pep, = group
@@ -108,6 +110,7 @@ def predict_on_batch_output(dataset,shorttask,group):
108
  model = load_model('models/mcpas/bestmodel_alphapeptide.hdf5',compile=False)
109
  #predict_on_batch
110
  output = model.predict_on_batch([alpha_np,pep_np])
 
111
  elif dataset=='mcpas' and shorttask=='bp':
112
  #load data
113
  beta, pep = group
@@ -116,6 +119,7 @@ def predict_on_batch_output(dataset,shorttask,group):
116
  model = load_model('models/mcpas/bestmodel_betapeptide.hdf5',compile=False)
117
  #predict_on_batch
118
  output = model.predict_on_batch([beta_np, pep_np])
 
119
  elif dataset=='mcpas' and shorttask=='apm':
120
  #load data
121
  alpha, pep, mhc = group
@@ -124,6 +128,7 @@ def predict_on_batch_output(dataset,shorttask,group):
124
  model = load_model('models/mcpas/bestmodel_alphapeptidemhc.hdf5',compile=False)
125
  #predict_on_batch
126
  output = model.predict_on_batch([alpha_np, pep_np, mhc_np])
 
127
  elif dataset=='mcpas' and shorttask=='bpm':
128
  #load data
129
  beta, pep, mhc = group
@@ -132,6 +137,7 @@ def predict_on_batch_output(dataset,shorttask,group):
132
  model = load_model('models/mcpas/bestmodel_betapeptidemhc.hdf5',compile=False)
133
  #predict_on_batch
134
  output = model.predict_on_batch([beta_np, pep_np, mhc_np])
 
135
  elif dataset=='vdjdb' and shorttask=='abp':
136
  #load data
137
  alpha, beta, pep = group
@@ -140,6 +146,7 @@ def predict_on_batch_output(dataset,shorttask,group):
140
  model = load_model('models/vdjdb/bestmodel_alphabetapeptide.hdf5',compile=False)
141
  #predict_on_batch
142
  output = model.predict_on_batch([alpha_np, beta_np, pep_np])
 
143
  elif dataset=='vdjdb' and shorttask=='abpm':
144
  #load data
145
  alpha, beta, pep, mhc = group
@@ -148,7 +155,8 @@ def predict_on_batch_output(dataset,shorttask,group):
148
  model = load_model('models/vdjdb/bestmodel_alphabetapeptidemhc.hdf5',compile=False)
149
  #predict_on_batch
150
  output = model.predict_on_batch([alpha_np, beta_np, pep_np, mhc_np])
151
- elif dataset=='vdjdb' and shorttask=='ap':
 
152
  #load data
153
  alpha, pep, = group
154
  alpha_np, pep_np, = np.load(alpha), np.load(pep)
@@ -156,6 +164,7 @@ def predict_on_batch_output(dataset,shorttask,group):
156
  model = load_model('models/vdjdb/bestmodel_alphapeptide.hdf5',compile=False)
157
  #predict_on_batch
158
  output = model.predict_on_batch([alpha_np, pep_np])
 
159
  elif dataset=='vdjdb' and shorttask=='bp':
160
  #load data
161
  beta, pep = group
@@ -164,6 +173,7 @@ def predict_on_batch_output(dataset,shorttask,group):
164
  model = load_model('models/vdjdb/bestmodel_betapeptide.hdf5',compile=False)
165
  #predict_on_batch
166
  output = model.predict_on_batch([beta_np, pep_np])
 
167
  elif dataset=='vdjdb' and shorttask=='apm':
168
  #load data
169
  alpha, pep, mhc = group
@@ -172,6 +182,7 @@ def predict_on_batch_output(dataset,shorttask,group):
172
  model = load_model('models/vdjdb/bestmodel_alphapeptidemhc.hdf5',compile=False)
173
  #predict_on_batch
174
  output = model.predict_on_batch([alpha_np, pep_np, mhc_np])
 
175
  elif dataset=='vdjdb' and shorttask=='bpm':
176
  #load data
177
  beta, pep, mhc = group
@@ -186,7 +197,7 @@ def predict_on_batch_output(dataset,shorttask,group):
186
  val = np.squeeze(output)
187
  return val
188
 
189
- # @st.cache_data
190
  def convert_df(df):
191
  # IMPORTANT: Cache the conversion to prevent computation on every rerun
192
  return df.to_csv().encode('utf-8')
@@ -196,7 +207,9 @@ def convert_df(df):
196
  if st.button('Submit'):
197
  # with st.spinner('Wait for it...'):
198
  # time.sleep(0.5)
199
- # res = predict_on_batch_output(dataset,shorttask,group)
 
 
200
  # st.write("Binding Probabilities")
201
  # st.dataframe((np.round(res, 4)))
202
  # csv = convert_df(pd.DataFrame(np.round(res, 4), columns=['output']))
@@ -206,19 +219,24 @@ if st.button('Submit'):
206
  with st.spinner('Calculating ...'):
207
  time.sleep(0.5)
208
  st.write("Binding Probabilities")
209
- st.dataframe((np.round(res, 4)), use_container_width=500, height=500)
210
- csv = convert_df(pd.DataFrame(np.round(res, 4), columns=['output']))
 
 
 
 
 
211
  st.download_button(label="Download Predictions",data=csv,file_name='tcresm_predictions.csv', mime='text/csv')
212
  except:
213
- st.error('Please ensure you have uploaded the files and chosen the correct model before pressing the Submit button', icon="🚨")
214
 
215
 
216
 
217
- # if st.button("Clear All"):
218
- # # Clear values from *all* all in-memory and on-disk data caches:
219
- # # i.e. clear values from both square and cube
220
- # st.cache.clear()
221
 
222
 
223
 
224
- st.caption('Developed By: Shashank Yadav : shashank[at]arizona.edu', unsafe_allow_html=True)
 
75
 
76
 
77
  ##################### ML predict function
78
+ @st.cache_data
79
  def predict_on_batch_output(dataset,shorttask,group):
80
 
81
  if dataset == 'MCPAS':
 
92
  model = load_model('models/mcpas/bestmodel_alphabetapeptide.hdf5',compile=False)
93
  #predict_on_batch
94
  output = model.predict_on_batch([alpha_np, beta_np, pep_np])
95
+
96
  elif dataset=='mcpas' and shorttask=='abpm':
97
  #load data
98
  alpha, beta, pep, mhc = group
 
101
  model = load_model('models/mcpas/bestmodel_alphabetaptptidemhc.hdf5',compile=False)
102
  #predict_on_batch
103
  output = model.predict_on_batch([alpha_np, beta_np, pep_np, mhc_np])
104
+
105
  elif dataset=='mcpas' and shorttask=='ap':
106
  #load data
107
  alpha, pep, = group
 
110
  model = load_model('models/mcpas/bestmodel_alphapeptide.hdf5',compile=False)
111
  #predict_on_batch
112
  output = model.predict_on_batch([alpha_np,pep_np])
113
+
114
  elif dataset=='mcpas' and shorttask=='bp':
115
  #load data
116
  beta, pep = group
 
119
  model = load_model('models/mcpas/bestmodel_betapeptide.hdf5',compile=False)
120
  #predict_on_batch
121
  output = model.predict_on_batch([beta_np, pep_np])
122
+
123
  elif dataset=='mcpas' and shorttask=='apm':
124
  #load data
125
  alpha, pep, mhc = group
 
128
  model = load_model('models/mcpas/bestmodel_alphapeptidemhc.hdf5',compile=False)
129
  #predict_on_batch
130
  output = model.predict_on_batch([alpha_np, pep_np, mhc_np])
131
+
132
  elif dataset=='mcpas' and shorttask=='bpm':
133
  #load data
134
  beta, pep, mhc = group
 
137
  model = load_model('models/mcpas/bestmodel_betapeptidemhc.hdf5',compile=False)
138
  #predict_on_batch
139
  output = model.predict_on_batch([beta_np, pep_np, mhc_np])
140
+
141
  elif dataset=='vdjdb' and shorttask=='abp':
142
  #load data
143
  alpha, beta, pep = group
 
146
  model = load_model('models/vdjdb/bestmodel_alphabetapeptide.hdf5',compile=False)
147
  #predict_on_batch
148
  output = model.predict_on_batch([alpha_np, beta_np, pep_np])
149
+
150
  elif dataset=='vdjdb' and shorttask=='abpm':
151
  #load data
152
  alpha, beta, pep, mhc = group
 
155
  model = load_model('models/vdjdb/bestmodel_alphabetapeptidemhc.hdf5',compile=False)
156
  #predict_on_batch
157
  output = model.predict_on_batch([alpha_np, beta_np, pep_np, mhc_np])
158
+
159
+ elif dataset=='vdjdb' and shorttask=='ap':
160
  #load data
161
  alpha, pep, = group
162
  alpha_np, pep_np, = np.load(alpha), np.load(pep)
 
164
  model = load_model('models/vdjdb/bestmodel_alphapeptide.hdf5',compile=False)
165
  #predict_on_batch
166
  output = model.predict_on_batch([alpha_np, pep_np])
167
+
168
  elif dataset=='vdjdb' and shorttask=='bp':
169
  #load data
170
  beta, pep = group
 
173
  model = load_model('models/vdjdb/bestmodel_betapeptide.hdf5',compile=False)
174
  #predict_on_batch
175
  output = model.predict_on_batch([beta_np, pep_np])
176
+
177
  elif dataset=='vdjdb' and shorttask=='apm':
178
  #load data
179
  alpha, pep, mhc = group
 
182
  model = load_model('models/vdjdb/bestmodel_alphapeptidemhc.hdf5',compile=False)
183
  #predict_on_batch
184
  output = model.predict_on_batch([alpha_np, pep_np, mhc_np])
185
+
186
  elif dataset=='vdjdb' and shorttask=='bpm':
187
  #load data
188
  beta, pep, mhc = group
 
197
  val = np.squeeze(output)
198
  return val
199
 
200
+ @st.cache_data
201
  def convert_df(df):
202
  # IMPORTANT: Cache the conversion to prevent computation on every rerun
203
  return df.to_csv().encode('utf-8')
 
207
  if st.button('Submit'):
208
  # with st.spinner('Wait for it...'):
209
  # time.sleep(0.5)
210
+ # res = predict_on_batch_output(dataset,shorttask,group).flatten()
211
+ # # res = predict_output(dataset,shorttask,group)
212
+ # print(type(res), res.shape)
213
  # st.write("Binding Probabilities")
214
  # st.dataframe((np.round(res, 4)))
215
  # csv = convert_df(pd.DataFrame(np.round(res, 4), columns=['output']))
 
219
  with st.spinner('Calculating ...'):
220
  time.sleep(0.5)
221
  st.write("Binding Probabilities")
222
+ # st.dataframe(['Sample Number','Output'], use_container_width=500, height=500)
223
+ val_df = pd.DataFrame({'Sample Number': [item+1 for item in range(0,res.shape[0])], 'Binding Probability': res.tolist()})
224
+ val_df = val_df.set_index(['Sample Number'])
225
+ # st.dataframe(np.round(res, 4), use_container_width=500, height=500)
226
+ st.dataframe(val_df, use_container_width=500, height=500)
227
+ # csv = convert_df(pd.DataFrame(np.round(res, 4), columns=['output']))
228
+ csv = convert_df(val_df)
229
  st.download_button(label="Download Predictions",data=csv,file_name='tcresm_predictions.csv', mime='text/csv')
230
  except:
231
+ st.error('Please ensure you have uploaded the files before pressing the Submit button', icon="🚨")
232
 
233
 
234
 
235
+ if st.button("Clear All"):
236
+ # Clear values from *all* all in-memory and on-disk data caches:
237
+ # i.e. clear values from both square and cube
238
+ st.cache_data.clear()
239
 
240
 
241
 
242
+ st.caption('Developed By: Shashank Yadav : shashank[at]arizona.edu', unsafe_allow_html=True)