sagawa committed on
Commit
93d358b
·
verified ·
1 Parent(s): 33ff4f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -1
app.py CHANGED
@@ -7,10 +7,11 @@ import torch
7
  import itertools
8
  from torch.utils.data import DataLoader
9
  from transformers import AutoTokenizer
 
10
 
11
  sys.path.append("scripts/")
12
  from foldseek_util import get_struc_seq
13
- from utils import seed_everything
14
  from models import PLTNUM_PreTrainedModel
15
  from datasets_ import PLTNUMDataset
16
 
@@ -26,6 +27,7 @@ class Config:
26
  self.task = "classification"
27
  self.sequence_col = "sequence"
28
  self.seed = 42
 
29
 
30
 
31
 
@@ -142,6 +144,71 @@ def predict(cfg, sequences):
142
  return predictions, [1 if x > 0.5 else 0 for x in predictions]
143
 
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  # Gradio Interface
146
  with gr.Blocks() as demo:
147
  gr.Markdown(
@@ -184,6 +251,17 @@ with gr.Blocks() as demo:
184
  outputs=prediction_output,
185
  )
186
 
 
 
 
 
 
 
 
 
 
 
 
187
  with gr.TabItem("Enter Protein Sequence"):
188
  gr.Markdown("### Enter the protein sequence:")
189
  sequence = gr.Textbox(
 
7
  import itertools
8
  from torch.utils.data import DataLoader
9
  from transformers import AutoTokenizer
10
+ import shap
11
 
12
  sys.path.append("scripts/")
13
  from foldseek_util import get_struc_seq
14
+ from utils import seed_everything, save_pickle
15
  from models import PLTNUM_PreTrainedModel
16
  from datasets_ import PLTNUMDataset
17
 
 
27
  self.task = "classification"
28
  self.sequence_col = "sequence"
29
  self.seed = 42
30
+ self.max_evals = 10
31
 
32
 
33
 
 
144
  return predictions, [1 if x > 0.5 else 0 for x in predictions]
145
 
146
 
147
+
148
def calculate_shap_values_with_pdb(model_choice, organism_choice, pdb_files, cfg=None):
    """Compute SHAP values for uploaded PDB files and pickle them for download.

    Each file is converted with ``get_foldseek_seq``; one element of the
    result is chosen per file based on ``model_choice`` and the collected
    sequences are handed to ``calculate_shap_values_core``.

    Args:
        model_choice: Model name selected in the UI. ``"SaProt"`` selects
            element 2 of the foldseek result, anything else selects element 0.
        organism_choice: Forwarded unchanged to ``calculate_shap_values_core``.
        pdb_files: Iterable of uploaded file objects exposing a ``.name`` path.
        cfg: Optional configuration object. A fresh ``Config`` is created when
            omitted, so attributes set during one call do not leak into the
            next one (a ``Config()`` default would be shared by all calls).

    Returns:
        Path of the pickle file that contains the SHAP values.

    Raises:
        gr.Error: If no PDB file was provided.
    """
    if not pdb_files:
        # Covers both ``None`` (nothing uploaded) and an empty list.
        raise gr.Error("Please upload at least one PDB file.")
    if cfg is None:
        cfg = Config()

    # The permission change does not depend on the file being processed, so it
    # is done once per call instead of once per PDB file.
    os.system("chmod 777 bin/foldseek")

    use_saprot = model_choice == "SaProt"
    input_sequences = []
    for pdb_file in pdb_files:
        sequences = get_foldseek_seq(pdb_file.name)
        input_sequences.append(sequences[2] if use_saprot else sequences[0])

    shap_values = calculate_shap_values_core(
        model_choice, organism_choice, input_sequences, cfg
    )

    # NOTE(review): this fixed path is shared by every request, so concurrent
    # calls overwrite each other's result - confirm this is acceptable.
    output_path = "/tmp/shap_values.pkl"
    save_pickle(output_path, shap_values)

    return output_path
166
+
167
+
168
def calculate_shap_fn(texts, model, cfg):
    """Return the sigmoid of the model outputs for a batch of texts.

    Args:
        texts: Batch of input strings. May be a NumPy array or any plain
            Python sequence (list, tuple).
        model: Callable that takes the dict of tokenized tensors and returns
            a tensor.
        cfg: Object providing ``tokenizer``, ``max_length`` and ``device``.

    Returns:
        NumPy array holding ``sigmoid(model(inputs))``.
    """
    if len(texts) == 1:
        # A single-element batch is handed to the tokenizer as the bare item.
        texts = texts[0]
    elif hasattr(texts, "tolist"):
        # NumPy arrays are converted to a list of Python strings.
        texts = texts.tolist()
    else:
        # Plain sequences have no ``tolist``; copy them into a list instead.
        texts = list(texts)

    inputs = cfg.tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=cfg.max_length,
    )
    inputs = {k: v.to(cfg.device) for k, v in inputs.items()}

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(inputs)
        outputs = torch.sigmoid(outputs).detach().cpu().numpy()
    return outputs
186
+
187
+
188
def calculate_shap_values_core(model_choice, organism_choice, sequences, cfg=None):
    """Load the tokenizer and model from ``cfg.model_path`` and explain ``sequences``.

    Args:
        model_choice: Model name selected in the UI.
        organism_choice: Organism selected in the UI.
        sequences: Input sequences passed to the SHAP explainer.
        cfg: Optional configuration object. A fresh ``Config`` is created when
            omitted, because this function writes ``device`` and ``tokenizer``
            onto it and a ``Config()`` default would be shared by all calls.

    Returns:
        The object returned by calling the ``shap.Explainer`` on ``sequences``.
    """
    # NOTE(review): ``model_choice`` and ``organism_choice`` are not used in
    # this function; ``cfg.model_path`` is presumably expected to be set
    # already - confirm against the caller / ``Config``.
    if cfg is None:
        cfg = Config()

    cfg.device = "cuda" if torch.cuda.is_available() else "cpu"

    seed_everything(cfg.seed)
    tokenizer = AutoTokenizer.from_pretrained(
        cfg.model_path, padding_side=cfg.padding_side
    )
    cfg.tokenizer = tokenizer

    model = PLTNUM_PreTrainedModel.from_pretrained(cfg.model_path, cfg=cfg).to(cfg.device)
    model.eval()

    # Build an explainer using a token masker: the tokenizer is passed as the
    # masker and the prediction function closes over the model and config.
    explainer = shap.Explainer(lambda x: calculate_shap_fn(x, model, cfg), cfg.tokenizer)

    shap_values = explainer(
        sequences,
        batch_size=cfg.batch_size,
        max_evals=cfg.max_evals,
    )

    return shap_values
210
+
211
+
212
  # Gradio Interface
213
  with gr.Blocks() as demo:
214
  gr.Markdown(
 
251
  outputs=prediction_output,
252
  )
253
 
254
+ calculate_shap_values_button = gr.Button("Calculate SHAP Values")
255
+ shap_values_output = gr.File(
256
+ label="Download SHAP Values"
257
+ )
258
+ calculate_shap_values_button.click(
259
+ fn=calculate_shap_values_with_pdb,
260
+ inputs=[model_choice, organism_choice, pdb_files],
261
+ outputs=shap_values_output,
262
+ )
263
+
264
+
265
  with gr.TabItem("Enter Protein Sequence"):
266
  gr.Markdown("### Enter the protein sequence:")
267
  sequence = gr.Textbox(