Anton Bushuiev commited on
Commit
98b3032
·
1 Parent(s): 0c10579

Improve layout and examples

Browse files
Files changed (2) hide show
  1. app.py +29 -12
  2. assets/readme-dimer-close-up.png +0 -0
app.py CHANGED
@@ -77,7 +77,7 @@ def process_inputs(inputs, temp_dir):
77
  return pdb_path, ppi_path, muts
78
 
79
 
80
- def plot_3dmol(pdb_path, ppi_path, muts, attn):
81
  # 3DMol.js adapted from https://huggingface.co/spaces/huhlim/cg2all/blob/main/app.py
82
 
83
  # Read PDB for 3Dmol.js
@@ -92,17 +92,14 @@ def plot_3dmol(pdb_path, ppi_path, muts, attn):
92
  # Read PPI to customize 3Dmol.js visualization
93
  ppi_df = PandasPdb().read_pdb(ppi_path).df['ATOM']
94
  ppi_df = ppi_df.groupby(list(Residue._fields)).apply(lambda df: df[df['atom_name'] == 'CA'].iloc[0]).reset_index(drop=True)
95
- chains = ppi_df['chain_id'].unique()
96
  ppi_df['id'] = ppi_df.apply(lambda row: ':'.join([row['residue_name'], row['chain_id'], str(row['residue_number']), row['insertion']]), axis=1)
97
  ppi_df['id'] = ppi_df['id'].apply(lambda x: x[:-1] if x[-1] == ':' else x)
98
- muts_id = sum([Mutation(mut).wt_to_graphein() for mut in muts], start=[]) # flatten ids of all sp muts
99
  ppi_df['mutated'] = ppi_df.apply(lambda row: row['id'] in muts_id, axis=1)
100
 
101
  # Prepare attention coeffictients per residue (normalized sum of direct attention from mutated residues)
102
  attn = torch.nan_to_num(attn, nan=1e-10)
103
- # attn_sub = attn[:, 0, :, 0, :, :, :] # models, layers, heads, tokens, tokens
104
- # TODO Generalize to remove hardcoded 0 at dimension 1 correpsonding to useing attention for the 1st mutation
105
- attn_sub = attn[:, 0, 0, :, 0, :, :, :] # models, layers, heads, tokens, tokens
106
  idx_mutated = torch.from_numpy(ppi_df.index[ppi_df['mutated']].to_numpy())
107
  attn_sub = fill_diagonal(attn_sub, 1e-10)
108
  attn_mutated = attn_sub[..., idx_mutated, :]
@@ -112,6 +109,8 @@ def plot_3dmol(pdb_path, ppi_path, muts, attn):
112
  attns_per_token += 1e-10
113
  ppi_df['attn'] = attns_per_token.numpy()
114
 
 
 
115
  # Customize 3Dmol.js visualization https://3dmol.csb.pitt.edu/doc/
116
  styles = []
117
  zoom_atoms = []
@@ -234,10 +233,28 @@ def predict(models, temp_dir, *inputs):
234
  return df, plot
235
 
236
 
237
- app = gr.Blocks()
238
  with app:
239
 
240
  # Input GUI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  with gr.Row():
242
  with gr.Column():
243
  gr.Markdown("## PPI structure")
@@ -253,12 +270,12 @@ with app:
253
 
254
  examples = gr.Examples(
255
  examples=[
256
- ["1BUI", "A,B,C", "SC16A;FC47A;SC16A,FC47A"],
257
- ["1KNE", "A,P", ';'.join([f"TP6{a}" for a in AMINO_ACID_CODES_1])],
258
- ["1C4Z", "A,B,C,D", "FA690A;KD100A"]
259
  ],
260
  inputs=[pdb_code, partners, muts],
261
- label="Examples (press line to fill inputs)"
262
  )
263
 
264
  # Predict GUI
@@ -295,4 +312,4 @@ with app:
295
  predict = partial(predict, models, temp_dir)
296
  predict_button.click(predict, inputs=inputs, outputs=outputs)
297
 
298
- app.launch()
 
77
  return pdb_path, ppi_path, muts
78
 
79
 
80
+ def plot_3dmol(pdb_path, ppi_path, muts, attn, mut_id=0):
81
  # 3DMol.js adapted from https://huggingface.co/spaces/huhlim/cg2all/blob/main/app.py
82
 
83
  # Read PDB for 3Dmol.js
 
92
  # Read PPI to customize 3Dmol.js visualization
93
  ppi_df = PandasPdb().read_pdb(ppi_path).df['ATOM']
94
  ppi_df = ppi_df.groupby(list(Residue._fields)).apply(lambda df: df[df['atom_name'] == 'CA'].iloc[0]).reset_index(drop=True)
 
95
  ppi_df['id'] = ppi_df.apply(lambda row: ':'.join([row['residue_name'], row['chain_id'], str(row['residue_number']), row['insertion']]), axis=1)
96
  ppi_df['id'] = ppi_df['id'].apply(lambda x: x[:-1] if x[-1] == ':' else x)
97
+ muts_id = Mutation(muts[mut_id]).wt_to_graphein() # flatten ids of all sp muts
98
  ppi_df['mutated'] = ppi_df.apply(lambda row: row['id'] in muts_id, axis=1)
99
 
100
  # Prepare attention coeffictients per residue (normalized sum of direct attention from mutated residues)
101
  attn = torch.nan_to_num(attn, nan=1e-10)
102
+ attn_sub = attn[:, mut_id, 0, :, 0, :, :, :] # models, layers, heads, tokens, tokens
 
 
103
  idx_mutated = torch.from_numpy(ppi_df.index[ppi_df['mutated']].to_numpy())
104
  attn_sub = fill_diagonal(attn_sub, 1e-10)
105
  attn_mutated = attn_sub[..., idx_mutated, :]
 
109
  attns_per_token += 1e-10
110
  ppi_df['attn'] = attns_per_token.numpy()
111
 
112
+ chains = ppi_df.sort_values('attn', ascending=False)['chain_id'].unique()
113
+
114
  # Customize 3Dmol.js visualization https://3dmol.csb.pitt.edu/doc/
115
  styles = []
116
  zoom_atoms = []
 
233
  return df, plot
234
 
235
 
236
+ app = gr.Blocks(theme=gr.themes.Default(primary_hue="green", secondary_hue="pink"))
237
  with app:
238
 
239
  # Input GUI
240
+ gr.Markdown(value="# PPIformer Web")
241
+ gr.Image("assets/readme-dimer-close-up.png")
242
+ gr.Markdown(value="""
243
+ [PPIformer](https://github.com/anton-bushuiev/PPIformer/tree/main) is a state-of-the-art predictor of the effects of mutations on protein-protein interactions (PPIs),
244
+ as quantified by the binding energy changes (ddG). The model was pre-trained on the [PPIRef](https://github.com/anton-bushuiev/PPIRef)
245
+ dataset via a coarse-grained structural masked modeling and fine-tuned on [SKEMPI v2.0](https://life.bsc.es/pid/skempi2) via log odds.
246
+ PPIformer was shown to successfully identify known favorable mutations of the [staphylokinase thrombolytic](https://pubmed.ncbi.nlm.nih.gov/10942387/)
247
+ and a [human antibody](https://www.pnas.org/doi/10.1073/pnas.2122954119) against the SARS-CoV-2 spike protein. Please see more details in [our paper](https://arxiv.org/abs/2310.18515).
248
+
249
+ To use PPIformer on your data, please specify the PPI structure (PDB code or file), interacting proteins of interest (chain codes in the file) and mutations
250
+ (semicolon-separated list or file with mutations in the [standard format](https://foldxsuite.crg.eu/parameter/mutant-file)). For inspiration, you can use one of the examples below:
251
+ click on one of the rows to pre-fill the inputs. After specifying the inputs, press the button to predict the effects of mutations on the PPI. Currently the model runs on CPU, so the prediction may take a few minutes.
252
+
253
+ After making a prediction with the model, you will see binding free energy changes (ddG values) for each mutation and a 3D visualization of the PPI with mutated residues highlighted in red. The visualization additionally shows
254
+ the attention coefficients of the model for the nearest neighboring residues, which quantifies the contribution of the residues to the predicted ddG value. The brighted and thicker a reisudes is, the more attention the model paid to it.
255
+ Currently, the web only visualizes the first mutation in the list.
256
+ """)
257
+
258
  with gr.Row():
259
  with gr.Column():
260
  gr.Markdown("## PPI structure")
 
270
 
271
  examples = gr.Examples(
272
  examples=[
273
+ ["1BUI", "A,B,C", "SC16A,FC47A;SC16A;FC47A"],
274
+ ["3QIB", "A,B,P,C,D", "YP7F,TP12S;YP7F;TP12S"],
275
+ ["1KNE", "A,P", ';'.join([f"TP6{a}" for a in AMINO_ACID_CODES_1])]
276
  ],
277
  inputs=[pdb_code, partners, muts],
278
+ label="Examples (click on a line to pre-fill inputs)"
279
  )
280
 
281
  # Predict GUI
 
312
  predict = partial(predict, models, temp_dir)
313
  predict_button.click(predict, inputs=inputs, outputs=outputs)
314
 
315
+ app.launch(allowed_paths=['./assets'])
assets/readme-dimer-close-up.png ADDED