Anton Bushuiev committed on
Commit
5626a5b
·
1 Parent(s): cf2ccb1

Implement basic ddG prediction

Browse files
Files changed (1) hide show
  1. app.py +133 -4
app.py CHANGED
@@ -1,7 +1,136 @@
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
def greet(name):
    """Return the greeting ``Hello <name>!!`` for the given *name* string."""
    greeting = "Hello " + name
    greeting += "!!"
    return greeting
 
 
5
 
6
# Expose ``greet`` through a text-in / text-out Gradio interface and start it.
iface = gr.Interface(
    fn=greet,
    inputs="text",
    outputs="text",
)
iface.launch()
 
1
+ import shutil
2
+ import tempfile
3
+ from pathlib import Path
4
+ from functools import partial
5
+
6
  import gradio as gr
7
+ import numpy as np
8
+ import torch
9
+
10
+ from mutils.pdb import download_pdb
11
+ from ppiref.extraction import PPIExtractor
12
+ from ppiref.utils.ppi import PPIPath
13
+ from ppiformer.tasks.node import DDGPPIformer
14
+ from ppiformer.utils.api import predict_ddg
15
+ from ppiformer.definitions import PPIFORMER_WEIGHTS_DIR
16
+
17
+
18
def process_inputs(inputs, temp_dir):
    """Validate the raw GUI values and convert them into prediction arguments.

    Args:
        inputs: Sequence of exactly five values, unpacked in this order:
            ``(pdb_code, pdb_path, partners, muts, muts_path)``.
            ``partners`` is a comma-separated string and mutations are
            separated by ``;`` (either typed in ``muts`` or read from the
            text file at ``muts_path``).
        temp_dir: ``pathlib.Path`` of a scratch directory. A downloaded PDB
            file is stored under ``pdb/`` and the extracted PPI under
            ``ppi/`` inside it.

    Returns:
        A ``(ppi_path, muts)`` tuple: the path built by
        ``PPIPath.construct`` for the extracted interaction, and the list of
        stripped, non-empty mutation strings.

    Raises:
        gr.Error: If the structure, partners or mutations are missing (or
            contain only separators/whitespace), if the PDB download fails,
            or if the PPI extraction fails.
    """
    pdb_code, pdb_path, partners, muts, muts_path = inputs

    # Check inputs
    if not pdb_code and not pdb_path:
        raise gr.Error("PPI structure not specified.")

    if pdb_code and pdb_path:
        gr.Warning("Both PDB code and PDB file specified. Using PDB file.")

    if not partners:
        raise gr.Error("Partners not specified.")

    if not muts and not muts_path:
        raise gr.Error("Mutations not specified.")

    if muts and muts_path:
        gr.Warning("Both mutations and mutations file specified. Using mutations file.")

    # Prepare PDB input (an uploaded file takes precedence over the PDB code)
    if pdb_path:
        pdb_path = Path(pdb_path)
    else:
        try:
            pdb_code = pdb_code.strip().lower()
            pdb_path = temp_dir / f'pdb/{pdb_code}.pdb'
            download_pdb(pdb_code, path=pdb_path)
        except Exception as exc:
            # `Exception` rather than a bare `except`, so that
            # KeyboardInterrupt/SystemExit are not converted into a GUI error;
            # the original cause stays attached for the server log.
            raise gr.Error("PDB download failed.") from exc

    # Drop empty entries produced by trailing or doubled commas ("A,B,").
    partners = [p.strip() for p in partners.split(',') if p.strip()]
    if not partners:
        raise gr.Error("Partners not specified.")

    # Extract PPI into temp dir
    try:
        ppi_dir = temp_dir / 'ppi'
        extractor = PPIExtractor(out_dir=ppi_dir, nest_out_dir=True, join=True, radius=10.0)
        extractor.extract(pdb_path, partners=partners)
        ppi_path = PPIPath.construct(ppi_dir, pdb_path.stem, partners)
    except Exception as exc:
        raise gr.Error("PPI extraction failed.") from exc

    # Prepare mutations input (a mutations file takes precedence over the textbox)
    if muts_path:
        muts_path = Path(muts_path)
        muts = muts_path.read_text()

    # Drop empty entries produced by a trailing ';' or blank file content.
    muts = [m.strip() for m in muts.split(';') if m.strip()]
    if not muts:
        raise gr.Error("Mutations not specified.")

    return ppi_path, muts
67
+
68
+
69
def predict(models, temp_dir, *inputs):
    """Predict ddG for the mutations entered in the GUI.

    Args:
        models: Forwarded unchanged as the first argument of ``predict_ddg``.
        temp_dir: Scratch directory forwarded to ``process_inputs``.
        *inputs: Raw GUI values; passed on to ``process_inputs`` as one tuple.

    Returns:
        A list of ``(mutation, ddg)`` pairs, one per mutation, in the order
        the mutations were given.
    """
    # Process input
    ppi_path, mutations = process_inputs(inputs, temp_dir)
    print(ppi_path, mutations)

    # Predict; the attention output is requested but not used further here.
    ddg_out, _attn = predict_ddg(models, ppi_path, mutations, return_attn=True)

    # Convert to plain Python values (presumably a torch tensor, given
    # `.detach().numpy()` -- confirm against `predict_ddg`).
    ddg_values = ddg_out.detach().numpy().tolist()
    return list(zip(mutations, ddg_values))
82
+
83
+
84
# Build the Gradio GUI, load the models and wire the button to the predictor.
# NOTE(review): the nesting of the layout contexts below is reconstructed from
# the order of the components -- confirm against the rendered layout.
with gr.Blocks() as app:

    # Input GUI
    with gr.Row():
        with gr.Column():
            gr.Markdown("## PPI structure")
            with gr.Row():
                pdb_code = gr.Textbox(label="PDB code", placeholder="1BUI")
                partners = gr.Textbox(label="Partners", placeholder="A,B,C")
            pdb_path = gr.File(
                label="Or PDB file instead of PDB code",
                file_count="single",
            )

        with gr.Column():
            gr.Markdown("## Mutations")
            muts = gr.Textbox(
                label="List of (multi-point) mutations",
                placeholder="SC16A,FC47A",
            )
            muts_path = gr.File(
                label="Or file with mutations",
                file_count="single",
            )

    examples = gr.Examples(
        examples=[["1BUI", "A,B,C", "SC16A;FC47A;SC16A,FC47A"]],
        inputs=[pdb_code, partners, muts],
        label="Examples (press line to fill)",
    )

    # Predict GUI
    predict_button = gr.Button(
        value="Predict effects of mutations on PPI",
        variant="primary",
    )

    # Output GUI
    gr.Markdown("## Predictions")
    df = gr.Dataframe(
        headers=["Mutation", "ddG"],
        datatype=["str", "number"],
        col_count=(2, "fixed"),
    )

    # Load models: checkpoints 0.ckpt..2.ckpt from the ddg_regression weights
    # directory, mapped to CPU and switched to eval mode.
    models = [
        DDGPPIformer.load_from_checkpoint(
            PPIFORMER_WEIGHTS_DIR / f'ddg_regression/{i}.ckpt',
            map_location=torch.device('cpu'),
        ).eval()
        for i in range(3)
    ]

    # Create temporary directory for storing downloaded PDBs and extracted PPIs.
    # The TemporaryDirectory object is kept in a module-level name so that it
    # is not finalized while the app is running.
    temp_dir_obj = tempfile.TemporaryDirectory()
    temp_dir = Path(temp_dir_obj.name)

    # Main logic: bind the models and scratch directory, leaving only the GUI
    # values as call-time arguments. This rebinds the module-level name
    # `predict` to the partial.
    inputs = [pdb_code, pdb_path, partners, muts, muts_path]
    predict = partial(predict, models, temp_dir)
    predict_button.click(predict, inputs=inputs, outputs=df)

app.launch()