Update app.py
Browse files
app.py
CHANGED
@@ -7,10 +7,11 @@ import torch
|
|
7 |
import itertools
|
8 |
from torch.utils.data import DataLoader
|
9 |
from transformers import AutoTokenizer
|
|
|
10 |
|
11 |
sys.path.append("scripts/")
|
12 |
from foldseek_util import get_struc_seq
|
13 |
-
from utils import seed_everything
|
14 |
from models import PLTNUM_PreTrainedModel
|
15 |
from datasets_ import PLTNUMDataset
|
16 |
|
@@ -26,6 +27,7 @@ class Config:
|
|
26 |
self.task = "classification"
|
27 |
self.sequence_col = "sequence"
|
28 |
self.seed = 42
|
|
|
29 |
|
30 |
|
31 |
|
@@ -142,6 +144,71 @@ def predict(cfg, sequences):
|
|
142 |
return predictions, [1 if x > 0.5 else 0 for x in predictions]
|
143 |
|
144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
# Gradio Interface
|
146 |
with gr.Blocks() as demo:
|
147 |
gr.Markdown(
|
@@ -184,6 +251,17 @@ with gr.Blocks() as demo:
|
|
184 |
outputs=prediction_output,
|
185 |
)
|
186 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
with gr.TabItem("Enter Protein Sequence"):
|
188 |
gr.Markdown("### Enter the protein sequence:")
|
189 |
sequence = gr.Textbox(
|
|
|
7 |
import itertools
|
8 |
from torch.utils.data import DataLoader
|
9 |
from transformers import AutoTokenizer
|
10 |
+
import shap
|
11 |
|
12 |
sys.path.append("scripts/")
|
13 |
from foldseek_util import get_struc_seq
|
14 |
+
from utils import seed_everything, save_pickle
|
15 |
from models import PLTNUM_PreTrainedModel
|
16 |
from datasets_ import PLTNUMDataset
|
17 |
|
|
|
27 |
self.task = "classification"
|
28 |
self.sequence_col = "sequence"
|
29 |
self.seed = 42
|
30 |
+
self.max_evals = 10
|
31 |
|
32 |
|
33 |
|
|
|
144 |
return predictions, [1 if x > 0.5 else 0 for x in predictions]
|
145 |
|
146 |
|
147 |
+
|
148 |
+
def calculate_shap_values_with_pdb(model_choice, organism_choice, pdb_files, cfg=None):
    """Compute SHAP values for uploaded PDB files and pickle them to disk.

    Each file is converted with ``get_foldseek_seq``; the element at index 2
    of its result is used when ``model_choice == "SaProt"``, otherwise the
    element at index 0. The collected sequences are passed to
    ``calculate_shap_values_core`` and the result is written with
    ``save_pickle``.

    Args:
        model_choice: Model name chosen in the UI. Here it is only compared
            against ``"SaProt"`` and forwarded to the core function.
        organism_choice: Forwarded unchanged to ``calculate_shap_values_core``.
        pdb_files: Iterable of uploaded-file objects exposing a ``.name``
            attribute holding the file path. May be ``None`` or empty.
        cfg: Configuration object. When omitted, a fresh ``Config()`` is
            created for this call so that state written onto it downstream
            is not shared between calls.

    Returns:
        The path of the written pickle file, or ``None`` when no files were
        supplied.
    """
    # Nothing uploaded: skip model loading entirely instead of failing on None.
    if not pdb_files:
        return None

    # A default of Config() in the signature would be built once at definition
    # time and shared by every call; build it lazily instead.
    if cfg is None:
        cfg = Config()

    # The permission fix does not depend on the file, so run it once up front.
    os.system("chmod 777 bin/foldseek")

    input_sequences = []
    for pdb_file in pdb_files:
        sequences = get_foldseek_seq(pdb_file.name)
        input_sequences.append(
            sequences[2] if model_choice == "SaProt" else sequences[0]
        )

    shap_values = calculate_shap_values_core(
        model_choice, organism_choice, input_sequences, cfg
    )

    # NOTE(review): fixed path — overlapping requests would overwrite each
    # other's result file; confirm whether a per-request path is needed.
    output_path = "/tmp/shap_values.pkl"
    save_pickle(output_path, shap_values)
    return output_path
|
166 |
+
|
167 |
+
|
168 |
+
def calculate_shap_fn(texts, model, cfg):
    """Score a batch of texts with the model and return sigmoid outputs.

    This is the prediction callable wrapped for ``shap.Explainer``.

    Args:
        texts: Sequence of input strings. Objects with a ``tolist`` method
            (e.g. NumPy arrays) and plain lists/tuples are both accepted.
        model: Callable that takes the dict of tokenized tensors as a single
            positional argument and returns a tensor of raw scores.
        cfg: Object providing ``tokenizer``, ``device`` and ``max_length``.

    Returns:
        NumPy array holding the sigmoid of the model output, moved to CPU.
    """
    if len(texts) == 1:
        # A single item is handed to the tokenizer as a bare element rather
        # than as a one-element list.
        batch = texts[0]
    elif hasattr(texts, "tolist"):
        batch = texts.tolist()
    else:
        # Plain Python sequences have no .tolist(); copy them into a list.
        batch = list(texts)

    encoded = cfg.tokenizer(
        batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=cfg.max_length,
    )
    # Move every tokenizer output onto the device the model runs on.
    encoded = {key: tensor.to(cfg.device) for key, tensor in encoded.items()}

    with torch.no_grad():
        scores = torch.sigmoid(model(encoded))
    return scores.detach().cpu().numpy()
|
186 |
+
|
187 |
+
|
188 |
+
def calculate_shap_values_core(model_choice, organism_choice, sequences, cfg=None):
    """Load the tokenizer and model named by ``cfg`` and explain ``sequences``.

    Args:
        model_choice: Not read by this function.
        organism_choice: Not read by this function.
        sequences: Inputs handed to the SHAP explainer.
        cfg: Configuration object providing ``seed``, ``model_path``,
            ``padding_side``, ``batch_size`` and ``max_evals``. ``device`` and
            ``tokenizer`` are written onto it here. When omitted, a fresh
            ``Config()`` is created for this call so those writes are not
            shared between calls.

    Returns:
        The object returned by calling the ``shap.Explainer`` on ``sequences``.
    """
    # A default of Config() in the signature would be built once at definition
    # time and then mutated below on every call; build it lazily instead.
    if cfg is None:
        cfg = Config()

    cfg.device = "cuda" if torch.cuda.is_available() else "cpu"

    seed_everything(cfg.seed)

    # NOTE(review): cfg.model_path is used as-is; model_choice/organism_choice
    # do not influence it here — confirm the caller sets it when required.
    cfg.tokenizer = AutoTokenizer.from_pretrained(
        cfg.model_path, padding_side=cfg.padding_side
    )

    model = PLTNUM_PreTrainedModel.from_pretrained(cfg.model_path, cfg=cfg).to(
        cfg.device
    )
    model.eval()

    # build an explainer using a token masker
    explainer = shap.Explainer(
        lambda x: calculate_shap_fn(x, model, cfg), cfg.tokenizer
    )

    return explainer(
        sequences,
        batch_size=cfg.batch_size,
        max_evals=cfg.max_evals,
    )
|
210 |
+
|
211 |
+
|
212 |
# Gradio Interface
|
213 |
with gr.Blocks() as demo:
|
214 |
gr.Markdown(
|
|
|
251 |
outputs=prediction_output,
|
252 |
)
|
253 |
|
254 |
+
calculate_shap_values_button = gr.Button("Calculate SHAP Values")
|
255 |
+
shap_values_output = gr.File(
|
256 |
+
label="Download SHAP Values"
|
257 |
+
)
|
258 |
+
calculate_shap_values_button.click(
|
259 |
+
fn=calculate_shap_values_with_pdb,
|
260 |
+
inputs=[model_choice, organism_choice, pdb_files],
|
261 |
+
outputs=shap_values_output,
|
262 |
+
)
|
263 |
+
|
264 |
+
|
265 |
with gr.TabItem("Enter Protein Sequence"):
|
266 |
gr.Markdown("### Enter the protein sequence:")
|
267 |
sequence = gr.Textbox(
|