Anton Bushuiev commited on
Commit
21d055c
1 Parent(s): dfeea55

Adapt to download weights and new attention format

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -18,7 +18,7 @@ from ppiref.extraction import PPIExtractor
18
  from ppiref.utils.ppi import PPIPath
19
  from ppiref.utils.residue import Residue
20
  from ppiformer.tasks.node import DDGPPIformer
21
- from ppiformer.utils.api import predict_ddg
22
  from ppiformer.utils.torch import fill_diagonal
23
  from ppiformer.definitions import PPIFORMER_WEIGHTS_DIR
24
 
@@ -100,7 +100,9 @@ def plot_3dmol(pdb_path, ppi_path, muts, attn):
100
 
101
  # Prepare attention coeffictients per residue (normalized sum of direct attention from mutated residues)
102
  attn = torch.nan_to_num(attn, nan=1e-10)
103
- attn_sub = attn[:, 0, :, 0, :, :, :] # models, layers, heads, tokens, tokens
 
 
104
  idx_mutated = torch.from_numpy(ppi_df.index[ppi_df['mutated']].to_numpy())
105
  attn_sub = fill_diagonal(attn_sub, 1e-10)
106
  attn_mutated = attn_sub[..., idx_mutated, :]
@@ -188,7 +190,7 @@ def plot_3dmol(pdb_path, ppi_path, muts, attn):
188
 
189
  $(document).ready(function () {
190
  let element = $("#container");
191
- let config = { backgroundColor: "black" };
192
  let viewer = $3Dmol.createViewer(element, config);
193
  viewer.addModel(pdb, "pdb");
194
  viewer.setStyle({"model": 0}, {"ray_opaque_background": "off"}, {"stick": {"color": "lightgrey", "opacity": 0.5}});
@@ -271,6 +273,9 @@ with app:
271
  )
272
  plot = gr.HTML()
273
 
 
 
 
274
  # Load models
275
  models = [
276
  DDGPPIformer.load_from_checkpoint(
 
18
  from ppiref.utils.ppi import PPIPath
19
  from ppiref.utils.residue import Residue
20
  from ppiformer.tasks.node import DDGPPIformer
21
+ from ppiformer.utils.api import download_weights, predict_ddg
22
  from ppiformer.utils.torch import fill_diagonal
23
  from ppiformer.definitions import PPIFORMER_WEIGHTS_DIR
24
 
 
100
 
101
  # Prepare attention coeffictients per residue (normalized sum of direct attention from mutated residues)
102
  attn = torch.nan_to_num(attn, nan=1e-10)
103
+ # attn_sub = attn[:, 0, :, 0, :, :, :] # models, layers, heads, tokens, tokens
104
+ # TODO Generalize to remove hardcoded 0 at dimension 1 correpsonding to useing attention for the 1st mutation
105
+ attn_sub = attn[:, 0, 0, :, 0, :, :, :] # models, layers, heads, tokens, tokens
106
  idx_mutated = torch.from_numpy(ppi_df.index[ppi_df['mutated']].to_numpy())
107
  attn_sub = fill_diagonal(attn_sub, 1e-10)
108
  attn_mutated = attn_sub[..., idx_mutated, :]
 
190
 
191
  $(document).ready(function () {
192
  let element = $("#container");
193
+ let config = { backgroundColor: "white" };
194
  let viewer = $3Dmol.createViewer(element, config);
195
  viewer.addModel(pdb, "pdb");
196
  viewer.setStyle({"model": 0}, {"ray_opaque_background": "off"}, {"stick": {"color": "lightgrey", "opacity": 0.5}});
 
273
  )
274
  plot = gr.HTML()
275
 
276
+ # Download weights from Zenodo
277
+ download_weights()
278
+
279
  # Load models
280
  models = [
281
  DDGPPIformer.load_from_checkpoint(