Anton Bushuiev commited on
Commit
c09238a
1 Parent(s): b432a65

Implement visualization dropdown and full complex inference

Browse files
Files changed (1) hide show
  1. app.py +111 -27
app.py CHANGED
@@ -50,7 +50,7 @@ def process_inputs(inputs, temp_dir):
50
 
51
  # Prepare PDB input
52
  if pdb_path:
53
- # remove '-' chars from pdb name
54
  new_pdb_path = temp_dir / f"pdb/{pdb_path.name.replace('_', '-')}"
55
  new_pdb_path.parent.mkdir(parents=True, exist_ok=True)
56
  shutil.copy(str(pdb_path), str(new_pdb_path))
@@ -63,9 +63,13 @@ def process_inputs(inputs, temp_dir):
63
  download_pdb(pdb_code, path=pdb_path)
64
  except:
65
  raise gr.Error("PDB download failed.")
66
-
 
67
  partners = list(map(lambda x: x.strip(), partners.split(',')))
68
 
 
 
 
69
  # Extract PPI into temp dir
70
  try:
71
  ppi_dir = temp_dir / 'ppi'
@@ -80,8 +84,8 @@ def process_inputs(inputs, temp_dir):
80
  muts_path = Path(muts_path)
81
  muts = muts_path.read_text()
82
 
83
- muts = list(map(lambda x: x.strip(), muts.split(';')))
84
-
85
  # Basic format
86
  try:
87
  muts = list(map(lambda m: Mutation.from_str(m), muts.split(';')))
@@ -92,7 +96,7 @@ def process_inputs(inputs, temp_dir):
92
  for mut in muts:
93
  for pmut in mut.muts:
94
  if pmut.chain not in partners:
95
- raise gr.Error(f'Chain of point mutation {pmut} from {mut} is not in the list of partners {partners}.')
96
 
97
  # Consistency with provided .pdb
98
  muts_on_interface = []
@@ -110,8 +114,8 @@ def process_inputs(inputs, temp_dir):
110
  return pdb_path, ppi_path, muts, muts_on_interface
111
 
112
 
113
- def plot_3dmol(pdb_path, ppi_path, muts, attn, mut_id=0):
114
- # 3DMol.js adapted from https://huggingface.co/spaces/huhlim/cg2all/blob/main/app.py
115
 
116
  # Read PDB for 3Dmol.js
117
  with open(pdb_path, "r") as fp:
@@ -127,12 +131,12 @@ def plot_3dmol(pdb_path, ppi_path, muts, attn, mut_id=0):
127
  ppi_df = ppi_df.groupby(list(Residue._fields)).apply(lambda df: df[df['atom_name'] == 'CA'].iloc[0]).reset_index(drop=True)
128
  ppi_df['id'] = ppi_df.apply(lambda row: ':'.join([row['residue_name'], row['chain_id'], str(row['residue_number']), row['insertion']]), axis=1)
129
  ppi_df['id'] = ppi_df['id'].apply(lambda x: x[:-1] if x[-1] == ':' else x)
130
- muts_id = Mutation(muts[mut_id]).wt_to_graphein() # flatten ids of all sp muts
131
  ppi_df['mutated'] = ppi_df.apply(lambda row: row['id'] in muts_id, axis=1)
132
 
133
  # Prepare attention coeffictients per residue (normalized sum of direct attention from mutated residues)
134
  attn = torch.nan_to_num(attn, nan=1e-10)
135
- attn_sub = attn[:, mut_id, 0, :, 0, :, :, :] # models, layers, heads, tokens, tokens
136
  idx_mutated = torch.from_numpy(ppi_df.index[ppi_df['mutated']].to_numpy())
137
  attn_sub = fill_diagonal(attn_sub, 1e-10)
138
  attn_mutated = attn_sub[..., idx_mutated, :]
@@ -235,7 +239,6 @@ def plot_3dmol(pdb_path, ppi_path, muts, attn, mut_id=0):
235
  </script>
236
  </body></html>"""
237
  )
238
- print(html)
239
 
240
  return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
241
  display-capture; encrypted-media;" sandbox="allow-modals allow-forms
@@ -246,29 +249,105 @@ def plot_3dmol(pdb_path, ppi_path, muts, attn, mut_id=0):
246
 
247
  def predict(models, temp_dir, *inputs):
248
  # Process input
249
- pdb_path, ppi_path, muts = process_inputs(inputs, temp_dir)
250
 
251
- print(ppi_path, muts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
- # Predict
254
- try:
255
- ddg, attn = predict_ddg(models, ppi_path, muts, return_attn=True)
256
- except:
257
- raise gr.Error("Prediction failed. Please double check your inputs.")
258
 
259
- # Create dataframe
260
- ddg = ddg.detach().numpy().tolist()
261
- ddg = np.round(ddg, 3)
262
- df = list(zip(muts, ddg))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
 
264
  # Create dataframe file
265
  path = 'ppiformer_ddg_predictions.csv'
266
- pd.DataFrame(df).rename(columns={0: "Mutation", 1: "ddG [kcal/mol]"}).to_csv(path, index=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
 
268
- # Create 3DMol plot
269
- plot = plot_3dmol(pdb_path, ppi_path, muts, attn)
270
 
271
- return df, path, plot
 
 
272
 
273
 
274
  app = gr.Blocks(theme=gr.themes.Default(primary_hue="green", secondary_hue="pink"))
@@ -303,7 +382,7 @@ with app:
303
 
304
  with gr.Column():
305
  gr.Markdown("## Mutations")
306
- muts = gr.Textbox(placeholder="SC16A;FC47A;SC16A,FC47A", label="List of (multi-point) mutations", info="SC16A,FC47A;SC16A;FC47A for three mutations: serine to alanine at position 16 in chain C, phenylalanine to alanine at position 47 in chain C, and their double-point combination")
307
  muts_path = gr.File(file_count="single", label="Or file with mutations")
308
 
309
  examples = gr.Examples(
@@ -327,6 +406,8 @@ with app:
327
  datatype=["str", "number"],
328
  col_count=(2, "fixed"),
329
  )
 
 
330
  plot = gr.HTML()
331
 
332
  # Download weights from Zenodo
@@ -347,8 +428,11 @@ with app:
347
 
348
  # Main logic
349
  inputs = [pdb_code, pdb_path, partners, muts, muts_path]
350
- outputs = [df, df_file, plot]
351
  predict = partial(predict, models, temp_dir)
352
  predict_button.click(predict, inputs=inputs, outputs=outputs)
353
 
 
 
 
354
  app.launch(allowed_paths=['./assets'])
 
50
 
51
  # Prepare PDB input
52
  if pdb_path:
53
+ # convert file name to PPIRef format
54
  new_pdb_path = temp_dir / f"pdb/{pdb_path.name.replace('_', '-')}"
55
  new_pdb_path.parent.mkdir(parents=True, exist_ok=True)
56
  shutil.copy(str(pdb_path), str(new_pdb_path))
 
63
  download_pdb(pdb_code, path=pdb_path)
64
  except:
65
  raise gr.Error("PDB download failed.")
66
+
67
+ # Parse partners
68
  partners = list(map(lambda x: x.strip(), partners.split(',')))
69
 
70
+ # Add partners to file name
71
+ pdb_path = pdb_path.rename(pdb_path.with_stem(f"{pdb_path.stem}_{'_'.join(partners)}"))
72
+
73
  # Extract PPI into temp dir
74
  try:
75
  ppi_dir = temp_dir / 'ppi'
 
84
  muts_path = Path(muts_path)
85
  muts = muts_path.read_text()
86
 
87
+ # Check mutations
88
+
89
  # Basic format
90
  try:
91
  muts = list(map(lambda m: Mutation.from_str(m), muts.split(';')))
 
96
  for mut in muts:
97
  for pmut in mut.muts:
98
  if pmut.chain not in partners:
99
+ raise gr.Error(f'Chain of point mutation {pmut} is not in the list of partners {partners}.')
100
 
101
  # Consistency with provided .pdb
102
  muts_on_interface = []
 
114
  return pdb_path, ppi_path, muts, muts_on_interface
115
 
116
 
117
+ def plot_3dmol(pdb_path, ppi_path, mut, attn, attn_mut_id=0):
118
+ # NOTE 3DMol.js adapted from https://huggingface.co/spaces/huhlim/cg2all/blob/main/app.py
119
 
120
  # Read PDB for 3Dmol.js
121
  with open(pdb_path, "r") as fp:
 
131
  ppi_df = ppi_df.groupby(list(Residue._fields)).apply(lambda df: df[df['atom_name'] == 'CA'].iloc[0]).reset_index(drop=True)
132
  ppi_df['id'] = ppi_df.apply(lambda row: ':'.join([row['residue_name'], row['chain_id'], str(row['residue_number']), row['insertion']]), axis=1)
133
  ppi_df['id'] = ppi_df['id'].apply(lambda x: x[:-1] if x[-1] == ':' else x)
134
+ muts_id = Mutation.from_str(mut).wt_to_graphein() # flatten ids of all sp muts
135
  ppi_df['mutated'] = ppi_df.apply(lambda row: row['id'] in muts_id, axis=1)
136
 
137
  # Prepare attention coeffictients per residue (normalized sum of direct attention from mutated residues)
138
  attn = torch.nan_to_num(attn, nan=1e-10)
139
+ attn_sub = attn[:, attn_mut_id, 0, :, 0, :, :, :] # models, layers, heads, tokens, tokens
140
  idx_mutated = torch.from_numpy(ppi_df.index[ppi_df['mutated']].to_numpy())
141
  attn_sub = fill_diagonal(attn_sub, 1e-10)
142
  attn_mutated = attn_sub[..., idx_mutated, :]
 
239
  </script>
240
  </body></html>"""
241
  )
 
242
 
243
  return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
244
  display-capture; encrypted-media;" sandbox="allow-modals allow-forms
 
249
 
250
  def predict(models, temp_dir, *inputs):
251
  # Process input
252
+ pdb_path, ppi_path, muts, muts_on_interface = process_inputs(inputs, temp_dir)
253
 
254
+ # Create dataframe
255
+ df = pd.DataFrame({
256
+ 'Mutation': muts,
257
+ 'ddG [kcal/mol]': len(muts) * [np.nan],
258
+ '10A Interface': muts_on_interface,
259
+ 'Attn Id': len(muts) * [np.nan],
260
+ })
261
+
262
+ # Show warning if some mutations are not on the interface
263
+ muts_not_on_interface = df[~df['10A Interface']]['Mutation'].tolist()
264
+ n_muts_not_on_interface = len(muts_not_on_interface)
265
+ if n_muts_not_on_interface:
266
+ n_muts_warn = 5
267
+ muts_not_on_interface = ';'.join(muts_not_on_interface[:n_muts_warn])
268
+ if n_muts_not_on_interface > n_muts_warn:
269
+ muts_not_on_interface += f'... (and {n_muts_not_on_interface - n_muts_warn} more)'
270
+ gr.Warning((
271
+ f"{muts_not_on_interface} {'is' if n_muts_not_on_interface == 1 else 'are'} not on the interface. "
272
+ "The model will predict the effects of mutations on the whole complex. "
273
+ "This may lead to less accurate predictions."
274
+ ))
275
+
276
+ # Predict using interface for mutations on the interface and using the whole complex otherwise
277
+ attn_ppi, attn_pdb = None, None
278
+ for df_sub, path in [
279
+ [df[df['10A Interface']], ppi_path],
280
+ [df[~df['10A Interface']], pdb_path]
281
+ ]:
282
+ if not len(df_sub):
283
+ continue
284
+
285
+ # Predict
286
+ try:
287
+ ddg, attn = predict_ddg(models, path, df_sub['Mutation'].tolist(), return_attn=True)
288
+ except:
289
+ raise gr.Error("Prediction failed. Please double check your inputs.")
290
+ ddg = ddg.detach().numpy().tolist()
291
 
292
+ # Update dataframe and attention tensor
293
+ idx = df_sub.index
294
+ df.loc[idx, 'ddG [kcal/mol]'] = ddg
295
+ df.loc[idx, 'Attn Id'] = np.arange(len(idx))
 
296
 
297
+ if path == ppi_path:
298
+ attn_ppi = attn
299
+ else:
300
+ attn_pdb = attn
301
+ df['Attn Id'] = df['Attn Id'].astype(int)
302
+
303
+
304
+ # Round ddG values
305
+ df['ddG [kcal/mol]'] = df['ddG [kcal/mol]'].round(3)
306
+
307
+ # Create PPI-specific dropdown
308
+ dropdown = gr.Dropdown(
309
+ df['Mutation'].tolist(), value=df['Mutation'].iloc[0],
310
+ interactive=True, visible=True, label="Mutation to visualize",
311
+ )
312
+
313
+ # Predefine plot arguments for all dropdown choices
314
+ dropdown_choices_to_plot_args = {
315
+ mut: (
316
+ pdb_path,
317
+ ppi_path if df[df['Mutation'] == mut]['10A Interface'].iloc[0] else pdb_path,
318
+ mut,
319
+ attn_ppi if df[df['Mutation'] == mut]['10A Interface'].iloc[0] else attn_pdb,
320
+ df[df['Mutation'] == mut]['Attn Id'].iloc[0]
321
+ )
322
+ for mut in df['Mutation']
323
+ }
324
 
325
  # Create dataframe file
326
  path = 'ppiformer_ddg_predictions.csv'
327
+ if n_muts_not_on_interface:
328
+ df = df[['Mutation', 'ddG [kcal/mol]', '10A Interface']]
329
+ df.to_csv(path, index=False)
330
+ df = gr.Dataframe(
331
+ value=df,
332
+ headers=['Mutation', 'ddG [kcal/mol]', '10A Interface'],
333
+ datatype=['str', 'number', 'bool'],
334
+ col_count=(3, 'fixed'),
335
+ )
336
+ else:
337
+ df = df[['Mutation', 'ddG [kcal/mol]']]
338
+ df.to_csv(path, index=False)
339
+ df = gr.Dataframe(
340
+ value=df,
341
+ headers=['Mutation', 'ddG [kcal/mol]'],
342
+ datatype=['str', 'number'],
343
+ col_count=(2, 'fixed'),
344
+ )
345
 
346
+ return df, path, dropdown, dropdown_choices_to_plot_args
 
347
 
348
+
349
+ def update_plot(dropdown, dropdown_choices_to_plot_args):
350
+ return plot_3dmol(*dropdown_choices_to_plot_args[dropdown])
351
 
352
 
353
  app = gr.Blocks(theme=gr.themes.Default(primary_hue="green", secondary_hue="pink"))
 
382
 
383
  with gr.Column():
384
  gr.Markdown("## Mutations")
385
+ muts = gr.Textbox(placeholder="SC16A;FC47A;SC16A,FC47A", label="List of (multi-point) mutations", info="SC16A;FC47A;SC16A,FC47A for three mutations: serine to alanine at position 16 in chain C, phenylalanine to alanine at position 47 in chain C, and their double-point combination")
386
  muts_path = gr.File(file_count="single", label="Or file with mutations")
387
 
388
  examples = gr.Examples(
 
406
  datatype=["str", "number"],
407
  col_count=(2, "fixed"),
408
  )
409
+ dropdown = gr.Dropdown(interactive=True, visible=False)
410
+ dropdown_choices_to_plot_args = gr.State([])
411
  plot = gr.HTML()
412
 
413
  # Download weights from Zenodo
 
428
 
429
  # Main logic
430
  inputs = [pdb_code, pdb_path, partners, muts, muts_path]
431
+ outputs = [df, df_file, dropdown, dropdown_choices_to_plot_args]
432
  predict = partial(predict, models, temp_dir)
433
  predict_button.click(predict, inputs=inputs, outputs=outputs)
434
 
435
+ # Update plot on dropdown change
436
+ dropdown.change(update_plot, inputs=[dropdown, dropdown_choices_to_plot_args], outputs=[plot])
437
+
438
  app.launch(allowed_paths=['./assets'])