Spaces:
Runtime error
Runtime error
Simon Duerr
commited on
Commit
·
b4346be
1
Parent(s):
3897ec7
add files
Browse files- app.py +245 -11
- utils/__pycache__/helpers.cpython-38.pyc +0 -0
- utils/__pycache__/model.cpython-38.pyc +0 -0
- utils/__pycache__/voxelization.cpython-38.pyc +0 -0
- utils/helpers.py +272 -0
- utils/model.py +42 -0
- utils/voxelization.py +209 -0
- weights/Metal3D.pth +3 -0
- weights/metal_0.5A_v3_d0.2_16Abox.pth +3 -0
app.py
CHANGED
@@ -1,19 +1,253 @@
|
|
1 |
import gradio as gr
|
2 |
|
3 |
-
|
4 |
-
|
|
|
|
|
5 |
|
6 |
-
|
|
|
|
|
|
|
7 |
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
gr.Markdown("# Metal3D")
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
with gr.Group():
|
11 |
-
inp = gr.Textbox(
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
mol = gr.HTML()
|
17 |
-
btn.click(fn=update, inputs=inp, outputs=out)
|
|
|
|
|
18 |
|
19 |
-
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
|
3 |
+
import urllib
|
4 |
+
import re
|
5 |
+
import sys
|
6 |
+
import warnings
|
7 |
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import ipywidgets as widgets
|
11 |
+
from ipywidgets import interact, fixed
|
12 |
|
13 |
+
from utils.helpers import *
|
14 |
+
from utils.voxelization import processStructures
|
15 |
+
from utils.model import Model
|
16 |
+
import numpy as np
|
17 |
+
|
18 |
+
import os
|
19 |
+
|
20 |
+
def update(inp, file, mode):
|
21 |
+
try:
|
22 |
+
pdb_file = file.name
|
23 |
+
except:
|
24 |
+
print("using pdbfile")
|
25 |
+
|
26 |
+
try:
|
27 |
+
pdb_file = inp
|
28 |
+
if (
|
29 |
+
re.match(
|
30 |
+
"[OPQ][0-9][A-Z0-9]{3}[0-9]|[A-NR-Z][0-9]([A-Z][A-Z0-9]{2}[0-9]){1,2}",
|
31 |
+
pdb_file,
|
32 |
+
).group()
|
33 |
+
== pdb_file
|
34 |
+
):
|
35 |
+
urllib.request.urlretrieve(
|
36 |
+
f"https://alphafold.ebi.ac.uk/files/AF-{pdb_file}-F1-model_v2.pdb",
|
37 |
+
f"files/{pdb_file}.pdb",
|
38 |
+
)
|
39 |
+
except AttributeError:
|
40 |
+
if len(inp) == 4:
|
41 |
+
pdb_file = inp
|
42 |
+
urllib.request.urlretrieve(
|
43 |
+
f"http://files.rcsb.org/download/{pdb_file.lower()}.pdb1",
|
44 |
+
f"files/{pdb_file}.pdb",
|
45 |
+
)
|
46 |
+
else:
|
47 |
+
return "pdb code must be 4 letters or Uniprot code does not match", ""
|
48 |
+
|
49 |
+
if mode == "All residues":
|
50 |
+
ids = get_all_protein_resids(
|
51 |
+
f"files/{pdb_file}.pdb",
|
52 |
+
)
|
53 |
+
else:
|
54 |
+
ids = get_all_metalbinding_resids(f"files/{pdb_file}.pdb")
|
55 |
+
|
56 |
+
voxels, prot_centers, prot_N, prots = processStructures(pdb_file, ids)
|
57 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
58 |
+
voxels.to(device)
|
59 |
+
print(voxels.shape)
|
60 |
+
model = Model()
|
61 |
+
model.to(device)
|
62 |
+
model.load_state_dict(torch.load("weights/metal_0.5A_v3_d0.2_16Abox.pth"))
|
63 |
+
model.eval()
|
64 |
+
with warnings.catch_warnings():
|
65 |
+
warnings.filterwarnings("ignore")
|
66 |
+
output = model(voxels)
|
67 |
+
print(output.shape)
|
68 |
+
prot_v = np.vstack(prot_centers)
|
69 |
+
output_v = output.flatten().cpu().detach().numpy()
|
70 |
+
bb = get_bb(prot_v)
|
71 |
+
gridres = 0.5
|
72 |
+
grid, box_N = create_grid_fromBB(bb, voxelSize=gridres)
|
73 |
+
probability_values = get_probability_mean(grid, prot_v, output_v)
|
74 |
+
print(probability_values.shape)
|
75 |
+
write_cubefile(
|
76 |
+
bb,
|
77 |
+
probability_values,
|
78 |
+
box_N,
|
79 |
+
outname=f"output/metal_{pdb_file}.cube",
|
80 |
+
gridres=gridres,
|
81 |
+
)
|
82 |
+
message = find_unique_sites(
|
83 |
+
probability_values,
|
84 |
+
grid,
|
85 |
+
writeprobes=True,
|
86 |
+
probefile=f"output/probes_{pdb_file}.pdb",
|
87 |
+
threshold=7,
|
88 |
+
p=0.15,
|
89 |
+
)
|
90 |
+
|
91 |
+
return message, molecule(
|
92 |
+
f"files/{pdb_file}.pdb",
|
93 |
+
f"output/probes_{pdb_file}.pdb",
|
94 |
+
f"output/metal_{pdb_file}.cube",
|
95 |
+
)
|
96 |
+
|
97 |
+
|
98 |
+
def test():
|
99 |
+
x = """<!DOCTYPE html>
|
100 |
+
<html>
|
101 |
+
<head>
|
102 |
+
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
|
103 |
+
</head>
|
104 |
+
<body>
|
105 |
+
<script src="https://3Dmol.org/build/3Dmol-min.js" async></script> <div style="height: 400px; width: 400px; position: relative;" class="viewer_3Dmoljs" data-pdb="2POR" data-backgroundcolor="0xffffff" data-style="stick" ></div>
|
106 |
+
</body></html>"""
|
107 |
+
return f"""<iframe style="width: 100%; height: 480px" name="result" allow="midi; geolocation; microphone; camera;
|
108 |
+
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
|
109 |
+
allow-scripts allow-same-origin allow-popups
|
110 |
+
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
|
111 |
+
allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
|
112 |
+
|
113 |
+
|
114 |
+
def read_mol(molpath):
|
115 |
+
with open(molpath, "r") as fp:
|
116 |
+
lines = fp.readlines()
|
117 |
+
mol = ""
|
118 |
+
for l in lines:
|
119 |
+
mol += l
|
120 |
+
return mol
|
121 |
+
|
122 |
+
|
123 |
+
def molecule(pdb, probes, cube):
|
124 |
+
mol = read_mol(pdb)
|
125 |
+
probes = read_mol(probes)
|
126 |
+
cubefile = read_mol(cube)
|
127 |
+
x = (
|
128 |
+
"""<!DOCTYPE html>
|
129 |
+
<html>
|
130 |
+
<head>
|
131 |
+
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
|
132 |
+
<style>
|
133 |
+
body{
|
134 |
+
font-family:sans-serif
|
135 |
+
}
|
136 |
+
.mol-container {
|
137 |
+
width: 100%;
|
138 |
+
height: 400px;
|
139 |
+
position: relative;
|
140 |
+
}
|
141 |
+
.slider{
|
142 |
+
width:80%;
|
143 |
+
margin:0 auto
|
144 |
+
}
|
145 |
+
.slidercontainer{
|
146 |
+
display:flex;
|
147 |
+
}
|
148 |
+
.slidercontainer > * + * {
|
149 |
+
margin-left: 0.5rem;
|
150 |
+
}
|
151 |
+
#isovalue{
|
152 |
+
text-align:right}
|
153 |
+
</style>
|
154 |
+
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
|
155 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/rangeslider.js/2.3.3/rangeslider.min.js" integrity="sha512-BUlWdwDeJo24GIubM+z40xcj/pjw7RuULBkxOTc+0L9BaGwZPwiwtbiSVzv31qR7TWx7bs6OPTE5IyfLOorboQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
|
156 |
+
</head>
|
157 |
+
<body>
|
158 |
+
<div class="slidercontainer">
|
159 |
+
<span>Isovalue </span>
|
160 |
+
<span id="isovalue">0.5</span>
|
161 |
+
<input class="slider" type="range" id="rangeslider" min="0" max="1" step="0.025" value=0.5>
|
162 |
+
</div>
|
163 |
+
|
164 |
+
<div id="container" class="mol-container"></div>
|
165 |
+
<script>
|
166 |
+
let viewer = null;
|
167 |
+
let voldata = null;
|
168 |
+
$(document).ready(function () {
|
169 |
+
let element = $("#container");
|
170 |
+
let config = { backgroundColor: "white" };
|
171 |
+
viewer = $3Dmol.createViewer( element, config );
|
172 |
+
viewer.ui.initiateUI();
|
173 |
+
let data = `"""
|
174 |
+
+ mol
|
175 |
+
+ """`
|
176 |
+
viewer.addModel( data, "pdb" );
|
177 |
+
|
178 |
+
let cubefile = `"""
|
179 |
+
+ cubefile
|
180 |
+
+ """`
|
181 |
+
voldata = new $3Dmol.VolumeData(cubefile, "cube");
|
182 |
+
viewer.addIsosurface(voldata, { isoval: 0.7 , color: "blue", alpha: 0.85, smoothness: 1 });
|
183 |
+
viewer.getModel(0).setStyle({}, {cartoon: {color: "grayCarbon"}});
|
184 |
+
let probes =`"""
|
185 |
+
+ probes
|
186 |
+
+ """`
|
187 |
+
viewer.addModel(probes, "pdb");
|
188 |
+
viewer.getModel(1).setStyle({ "resn": "ZN" }, { "sphere": { }});
|
189 |
+
viewer.getModel(1).setHoverable({}, true,
|
190 |
+
function (atom, viewer, event, container) {
|
191 |
+
if (!atom.label) {
|
192 |
+
atom.label = viewer.addLabel("ZN p=" + atom.pdbline.substring(55, 60), { position: atom, backgroundColor: "mintcream", fontColor: "black" });
|
193 |
+
}
|
194 |
+
},
|
195 |
+
function (atom, viewer) {
|
196 |
+
if (atom.label) {
|
197 |
+
viewer.removeLabel(atom.label);
|
198 |
+
delete atom.label;
|
199 |
+
}
|
200 |
+
}
|
201 |
+
);
|
202 |
+
viewer.zoomTo();
|
203 |
+
viewer.render();
|
204 |
+
viewer.zoom(0.8, 2000);
|
205 |
+
});
|
206 |
+
</script>
|
207 |
+
<script>
|
208 |
+
$("#rangeslider").rangeslider().on("change", function (el) {
|
209 |
+
isoval = parseFloat(el.target.value);
|
210 |
+
$("#isovalue").text(el.target.value)
|
211 |
+
viewer.addIsosurface(voldata, { isoval: parseFloat(el.target.value), color: "blue", alpha: 0.85, smoothness: 1 });
|
212 |
+
viewer.render();
|
213 |
+
});
|
214 |
+
</script>
|
215 |
+
</body></html>"""
|
216 |
+
)
|
217 |
+
|
218 |
+
return f"""<iframe style="width: 100%; height: 480px" name="result" allow="midi; geolocation; microphone; camera;
|
219 |
+
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
|
220 |
+
allow-scripts allow-same-origin allow-popups
|
221 |
+
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
|
222 |
+
allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
|
223 |
+
|
224 |
+
|
225 |
+
metal3d = gr.Blocks()
|
226 |
+
|
227 |
+
with metal3d:
|
228 |
gr.Markdown("# Metal3D")
|
229 |
+
gr.Markdown(
|
230 |
+
"""
|
231 |
+
Details about implementation and code available here:
|
232 |
+
>Duerr, Levy and Roethlisberger, Predicting zinc ion location using deep learning, BioRxiv, 2022 "
|
233 |
+
"""
|
234 |
+
)
|
235 |
with gr.Group():
|
236 |
+
inp = gr.Textbox(
|
237 |
+
placeholder="PDB Code or Uniprot identifier", label="Input molecule"
|
238 |
+
)
|
239 |
+
gr.Markdown("or upload a file")
|
240 |
+
file = gr.File(file_count="single", type="file")
|
241 |
+
mode = gr.Radio(
|
242 |
+
["All metalbinding residues (ASP, CYS, GLU, HIS)", "All residues"],
|
243 |
+
label="Residues to use for prediction",
|
244 |
+
)
|
245 |
+
btn = gr.Button("Run")
|
246 |
+
|
247 |
+
gr.Markdown("# Output")
|
248 |
+
out = gr.Textbox(label="status")
|
249 |
mol = gr.HTML()
|
250 |
+
btn.click(fn=update, inputs=[inp, file, mode], outputs=[out, mol])
|
251 |
+
|
252 |
+
metal3d.launch()
|
253 |
|
|
utils/__pycache__/helpers.cpython-38.pyc
ADDED
Binary file (7.65 kB). View file
|
|
utils/__pycache__/model.cpython-38.pyc
ADDED
Binary file (1.41 kB). View file
|
|
utils/__pycache__/voxelization.cpython-38.pyc
ADDED
Binary file (4.99 kB). View file
|
|
utils/helpers.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import multiprocessing
|
3 |
+
from multiprocessing import Pool
|
4 |
+
from turtle import width
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from moleculekit.molecule import Molecule
|
9 |
+
from scipy.spatial import KDTree
|
10 |
+
from sklearn.cluster import AgglomerativeClustering
|
11 |
+
|
12 |
+
|
13 |
+
def create_grid_fromBB(boundingBox, voxelSize=1):
|
14 |
+
"""Create a grid from a bounding box.
|
15 |
+
|
16 |
+
Parameters
|
17 |
+
----------
|
18 |
+
boundingBox : list
|
19 |
+
List of the form [xmin, xmax, ymin, ymax, zmin, zmax]
|
20 |
+
voxelSize : float
|
21 |
+
Size of the voxels in Angstrom
|
22 |
+
|
23 |
+
Returns
|
24 |
+
-------
|
25 |
+
grid : numpy.ndarray
|
26 |
+
Grid of shape (nx, ny, nz)
|
27 |
+
box_N : numpy.ndarray
|
28 |
+
Number of voxels in each dimension
|
29 |
+
|
30 |
+
"""
|
31 |
+
# increase grid by 0.5 to sample everything
|
32 |
+
xrange = np.arange(boundingBox[0][0], boundingBox[1][0] + 0.5, step=voxelSize)
|
33 |
+
yrange = np.arange(boundingBox[0][1], boundingBox[1][1] + 0.5, step=voxelSize)
|
34 |
+
zrange = np.arange(boundingBox[0][2], boundingBox[1][2] + 0.5, step=voxelSize)
|
35 |
+
|
36 |
+
gridpoints = np.zeros((xrange.shape[0] * yrange.shape[0] * zrange.shape[0], 3))
|
37 |
+
i = 0
|
38 |
+
for x in xrange:
|
39 |
+
for y in yrange:
|
40 |
+
for z in zrange:
|
41 |
+
gridpoints[i][0] = x
|
42 |
+
gridpoints[i][1] = y
|
43 |
+
gridpoints[i][2] = z
|
44 |
+
i += 1
|
45 |
+
return gridpoints, (xrange.shape[0], yrange.shape[0], zrange.shape[0])
|
46 |
+
|
47 |
+
|
48 |
+
def get_bb(points):
|
49 |
+
"""Return bounding box from a set of points (N,3)
|
50 |
+
|
51 |
+
Parameters
|
52 |
+
----------
|
53 |
+
points : numpy.ndarray
|
54 |
+
Set of points (N,3)
|
55 |
+
|
56 |
+
Returns
|
57 |
+
-------
|
58 |
+
boundingBox : list
|
59 |
+
List of the form [xmin, xmax, ymin, ymax, zmin, zmax]
|
60 |
+
|
61 |
+
"""
|
62 |
+
minx = np.min(points[:, 0])
|
63 |
+
maxx = np.max(points[:, 0])
|
64 |
+
|
65 |
+
miny = np.min(points[:, 1])
|
66 |
+
maxy = np.max(points[:, 1])
|
67 |
+
|
68 |
+
minz = np.min(points[:, 2])
|
69 |
+
maxz = np.max(points[:, 2])
|
70 |
+
bb = [[minx, miny, minz], [maxx, maxy, maxz]]
|
71 |
+
return bb
|
72 |
+
|
73 |
+
|
74 |
+
def get_all_protein_resids(pdb_file):
|
75 |
+
"""Return all protein residues from a pdb file
|
76 |
+
|
77 |
+
Parameters
|
78 |
+
----------
|
79 |
+
pdb_file : str
|
80 |
+
Path to pdb file
|
81 |
+
|
82 |
+
Returns
|
83 |
+
-------
|
84 |
+
resids : numpy.ndarray
|
85 |
+
Array of protein resids old -> new
|
86 |
+
|
87 |
+
"""
|
88 |
+
try:
|
89 |
+
prot = Molecule(pdb_file)
|
90 |
+
except:
|
91 |
+
exit("could not read file")
|
92 |
+
prot.filter("protein")
|
93 |
+
return prot.get("index", sel="name CA")
|
94 |
+
|
95 |
+
|
96 |
+
def get_all_metalbinding_resids(pdb_file):
|
97 |
+
"""Return all metal binding residues from a pdb file
|
98 |
+
|
99 |
+
Parameters
|
100 |
+
----------
|
101 |
+
pdb_file : str
|
102 |
+
Path to pdb file
|
103 |
+
|
104 |
+
Returns
|
105 |
+
-------
|
106 |
+
resids : numpy.ndarray
|
107 |
+
id of resids that are metal binding
|
108 |
+
|
109 |
+
"""
|
110 |
+
|
111 |
+
try:
|
112 |
+
prot = Molecule(pdb_file)
|
113 |
+
except:
|
114 |
+
exit("could not read file")
|
115 |
+
prot.filter("protein")
|
116 |
+
return prot.get(
|
117 |
+
"index",
|
118 |
+
sel="name CA and resname HIS HID HIE HIP CYS CYX GLU GLH GLN ASP ASH ASN GLN MET",
|
119 |
+
)
|
120 |
+
|
121 |
+
|
122 |
+
def compute_average_p_fast(point, cutoff=1):
|
123 |
+
"""Using KDTree find the closest gridpoints
|
124 |
+
|
125 |
+
Parameters
|
126 |
+
----------
|
127 |
+
point : numpy.ndarray
|
128 |
+
Point of shape (3,)
|
129 |
+
cutoff : float
|
130 |
+
Cutoff distance in Angstrom
|
131 |
+
|
132 |
+
Returns
|
133 |
+
-------
|
134 |
+
average_p : numpy.ndarray
|
135 |
+
Average probability of shape (1,)"""
|
136 |
+
p = 0
|
137 |
+
nearest_neighbors, indices = tree.query(
|
138 |
+
point, k=15, distance_upper_bound=cutoff, workers=1
|
139 |
+
)
|
140 |
+
if np.min(nearest_neighbors) != np.inf:
|
141 |
+
p = np.mean(output_v[indices[nearest_neighbors != np.inf]])
|
142 |
+
return p
|
143 |
+
|
144 |
+
|
145 |
+
def get_probability_mean(grid, prot_centers, pvalues):
|
146 |
+
"""Compute the mean probability of all gridpoints from the globalgrid based on the individual boxes
|
147 |
+
|
148 |
+
Parameters
|
149 |
+
----------
|
150 |
+
grid : numpy.ndarray
|
151 |
+
Grid of shape (nx, ny, nz)
|
152 |
+
prot_centers : numpy.ndarray
|
153 |
+
Protein centers of shape (N,3)
|
154 |
+
pvalues : numpy.ndarray
|
155 |
+
Probability values of shape (N,1)
|
156 |
+
|
157 |
+
Returns
|
158 |
+
-------
|
159 |
+
mean_p : numpy.ndarray
|
160 |
+
Mean probability over grid of shape (nx, ny, nz)
|
161 |
+
"""
|
162 |
+
global output_v
|
163 |
+
output_v = pvalues
|
164 |
+
global prot_v
|
165 |
+
prot_v = prot_centers
|
166 |
+
cpuCount = multiprocessing.cpu_count()
|
167 |
+
|
168 |
+
global tree
|
169 |
+
tree = KDTree(prot_v)
|
170 |
+
p = Pool(cpuCount)
|
171 |
+
results = p.map(compute_average_p_fast, grid)
|
172 |
+
return np.array(results)
|
173 |
+
|
174 |
+
|
175 |
+
def write_cubefile(bb, pvalues, box_N, outname="Metal3D_pmap.cube", gridres=1):
|
176 |
+
"""Write a cube file from a probability map
|
177 |
+
The cube specification from gaussian is used, distance are converted to bohr
|
178 |
+
|
179 |
+
Parameters
|
180 |
+
----------
|
181 |
+
bb : list
|
182 |
+
List of the form [xmin, xmax, ymin, ymax, zmin, zmax]
|
183 |
+
pvalues : numpy.ndarray
|
184 |
+
Probability values of shape (nx, ny, nz)
|
185 |
+
box_N : tuple
|
186 |
+
Number of voxels in each dimension
|
187 |
+
outname : str
|
188 |
+
Name of the output file
|
189 |
+
gridres:float
|
190 |
+
Resolution of the grid used for writing the voxels
|
191 |
+
|
192 |
+
"""
|
193 |
+
|
194 |
+
with open(outname, "w") as cube:
|
195 |
+
cube.write(" Metal3D Cube File\n")
|
196 |
+
cube.write(" Outer Loop: X, Middle Loop y, inner Loop z\n")
|
197 |
+
|
198 |
+
angstromToBohr = 1.89
|
199 |
+
cube.write(
|
200 |
+
f" 1 {bb[0][0]*angstromToBohr: .6f} {bb[0][1]*angstromToBohr: .6f} {bb[0][2]*angstromToBohr: .6f}\n"
|
201 |
+
)
|
202 |
+
cube.write(
|
203 |
+
f"{str(box_N[0]).rjust(5)} {1.890000*gridres:.9f} 0.000000 0.000000\n"
|
204 |
+
)
|
205 |
+
cube.write(
|
206 |
+
f"{str(box_N[1]).rjust(5)} 0.000000 {1.890000*gridres:.9f} 0.000000\n"
|
207 |
+
)
|
208 |
+
cube.write(
|
209 |
+
f"{str(box_N[2]).rjust(5)} 0.000000 0.000000 {1.890000*gridres:.9f}\n"
|
210 |
+
)
|
211 |
+
cube.write(" 1 1.000000 0.000000 0.000000 0.000000\n")
|
212 |
+
|
213 |
+
o = pvalues.reshape(box_N)
|
214 |
+
for x in range(box_N[0]):
|
215 |
+
for y in range(box_N[1]):
|
216 |
+
for z in range(box_N[2]):
|
217 |
+
cube.write(f" {o[x][y][z]: .5E}")
|
218 |
+
if z % 6 == 5:
|
219 |
+
cube.write("\n")
|
220 |
+
cube.write("\n")
|
221 |
+
|
222 |
+
|
223 |
+
def find_unique_sites(
|
224 |
+
pvalues, grid, writeprobes=False, probefile="probes.pdb", threshold=5, p=0.75
|
225 |
+
):
|
226 |
+
"""The probability voxels are points and the voxel clouds may contain multiple metals
|
227 |
+
This function finds the unique sites and returns the coordinates of the unique sites with the highest p for each cluster.
|
228 |
+
It uses the AgglomerativeClustering algorithm to find the unique sites.
|
229 |
+
The threshold is the maximum distance between two points in the same cluster it can be changed to get more metal points.
|
230 |
+
|
231 |
+
Parameters
|
232 |
+
----------
|
233 |
+
pvalues : numpy.ndarray
|
234 |
+
Probability values of shape (N, 1)
|
235 |
+
grid : numpy.ndarray
|
236 |
+
Grid of shape (N, 3)
|
237 |
+
writeprobes : bool
|
238 |
+
If True, write the probes to a pdb file
|
239 |
+
probefile : str
|
240 |
+
Name of the output file
|
241 |
+
threshold : float
|
242 |
+
Maximum distance between two points in the same cluster
|
243 |
+
p : float
|
244 |
+
Minimum probability of a point to be considered a unique site
|
245 |
+
|
246 |
+
"""
|
247 |
+
|
248 |
+
points = grid[pvalues > p]
|
249 |
+
point_p = pvalues[pvalues > p]
|
250 |
+
if len(points) == 0:
|
251 |
+
return "no metals found"
|
252 |
+
clustering = AgglomerativeClustering(
|
253 |
+
n_clusters=None, linkage="complete", distance_threshold=threshold
|
254 |
+
).fit(points)
|
255 |
+
|
256 |
+
message = f"min metal p={p}, n(metals) found: {clustering.n_clusters_}"
|
257 |
+
|
258 |
+
sites = []
|
259 |
+
for i in range(clustering.n_clusters_):
|
260 |
+
c_points = points[clustering.labels_ == i]
|
261 |
+
c_points_p = point_p[clustering.labels_ == i]
|
262 |
+
|
263 |
+
position = c_points[np.argmax(c_points_p)]
|
264 |
+
sites.append((position, np.max(c_points_p)))
|
265 |
+
if writeprobes:
|
266 |
+
print(f"writing probes to {probefile}")
|
267 |
+
with open(probefile, "w") as f:
|
268 |
+
for i, site in enumerate(sites):
|
269 |
+
f.write(
|
270 |
+
f"HETATM {i+1:3} ZN ZN A {i+1:3} {site[0][0]: 8.3f}{site[0][1]: 8.3f}{site[0][2]: 8.3f} {site[1]:.2f} 0.0 ZN2+\n"
|
271 |
+
)
|
272 |
+
return message
|
utils/model.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class Model(nn.Module):
|
9 |
+
"""Model with same padding
|
10 |
+
Conv5 uses a large filter size to aggregate the features from the whole box"""
|
11 |
+
|
12 |
+
def __init__(self):
|
13 |
+
super(Model, self).__init__()
|
14 |
+
self.conv1 = nn.Conv3d(8, 32, 3, padding="same")
|
15 |
+
self.conv2 = nn.Conv3d(32, 64, 3, padding="same")
|
16 |
+
self.conv3 = nn.Conv3d(64, 80, 3, padding="same")
|
17 |
+
self.conv4 = nn.Conv3d(80, 20, 3, padding="same")
|
18 |
+
self.conv5 = nn.Conv3d(20, 20, 20, padding="same")
|
19 |
+
self.conv6 = nn.Conv3d(20, 16, 3, padding="same")
|
20 |
+
self.conv7 = nn.Conv3d(16, 1, 3, padding="same")
|
21 |
+
self.dropout1 = nn.Dropout(0.2)
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
x = self.conv1(x)
|
25 |
+
x = F.relu(x)
|
26 |
+
x = self.conv2(x)
|
27 |
+
x = F.relu(x)
|
28 |
+
x = self.conv3(x)
|
29 |
+
x = F.relu(x)
|
30 |
+
|
31 |
+
x = self.conv4(x)
|
32 |
+
x = F.relu(x)
|
33 |
+
|
34 |
+
x = self.conv5(x)
|
35 |
+
x = F.relu(x)
|
36 |
+
x = self.dropout1(x)
|
37 |
+
x = self.conv6(x)
|
38 |
+
x = F.relu(x)
|
39 |
+
|
40 |
+
x = self.conv7(x)
|
41 |
+
x = torch.sigmoid(x)
|
42 |
+
return x
|
utils/voxelization.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import multiprocessing
|
3 |
+
from multiprocessing import Pool
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from moleculekit.molecule import Molecule
|
9 |
+
from moleculekit.tools.voxeldescriptors import getVoxelDescriptors
|
10 |
+
from moleculekit.tools.atomtyper import prepareProteinForAtomtyping
|
11 |
+
from moleculekit.tools.preparation import systemPrepare
|
12 |
+
|
13 |
+
|
14 |
+
class AtomtypingError(Exception):
|
15 |
+
pass
|
16 |
+
|
17 |
+
|
18 |
+
class StructureCleaningError(Exception):
|
19 |
+
pass
|
20 |
+
|
21 |
+
|
22 |
+
class ProteinPrepareError(Exception):
|
23 |
+
pass
|
24 |
+
|
25 |
+
|
26 |
+
class VoxelizationError(Exception):
|
27 |
+
pass
|
28 |
+
|
29 |
+
|
30 |
+
metal_atypes = (
|
31 |
+
"MG",
|
32 |
+
"ZN",
|
33 |
+
"MN",
|
34 |
+
"CA",
|
35 |
+
"FE",
|
36 |
+
"HG",
|
37 |
+
"CD",
|
38 |
+
"NI",
|
39 |
+
"CO",
|
40 |
+
"CU",
|
41 |
+
"K",
|
42 |
+
"LI",
|
43 |
+
"Mg",
|
44 |
+
"Zn",
|
45 |
+
"Mn",
|
46 |
+
"Ca",
|
47 |
+
"Fe",
|
48 |
+
"Hg",
|
49 |
+
"Cd",
|
50 |
+
"Ni",
|
51 |
+
"Co",
|
52 |
+
"Cu",
|
53 |
+
"Li",
|
54 |
+
)
|
55 |
+
|
56 |
+
|
57 |
+
def voxelize_single_notcentered(env):
|
58 |
+
"""voxelize 1 structure, executed on a single CPU
|
59 |
+
Using 7 of the 8 channels supplied by moleculekit(excluding metals)
|
60 |
+
Additionally it uses all the metalbinding residues as channel
|
61 |
+
|
62 |
+
Parameters
|
63 |
+
----------
|
64 |
+
env : tuple
|
65 |
+
Tuple of the form (prot, idx)
|
66 |
+
|
67 |
+
Returns
|
68 |
+
-------
|
69 |
+
voxels : torch.tensor
|
70 |
+
Voxelized structure with 8 channels (8,20,20,20)
|
71 |
+
prot_centers : list
|
72 |
+
List of the centers of the voxels (20x20x20,3)
|
73 |
+
prot_n : list
|
74 |
+
List of the number of voxels in each voxel (20x20x20)
|
75 |
+
prot : moleculekit.Molecule
|
76 |
+
Moleculekit molecule
|
77 |
+
"""
|
78 |
+
prot, id = env
|
79 |
+
|
80 |
+
c = prot.get("coords", sel=f"index {id} and name CA")
|
81 |
+
|
82 |
+
size = [16, 16, 16] # size of box
|
83 |
+
voxels = torch.zeros(8, 32, 32, 32)
|
84 |
+
|
85 |
+
try:
|
86 |
+
hydrophobic = prot.atomselect("element C")
|
87 |
+
hydrophobic = hydrophobic.reshape(hydrophobic.shape[0], 1)
|
88 |
+
|
89 |
+
aromatic = prot.atomselect(
|
90 |
+
"resname HIS HIE HIP HID TRP TYR PHE and sidechain and not name CB and not hydrogen"
|
91 |
+
)
|
92 |
+
aromatic = aromatic.reshape(aromatic.shape[0], 1)
|
93 |
+
|
94 |
+
metalcoordination = prot.atomselect(
|
95 |
+
"(name ND1 NE2 SG OE1 OE2 OD2) or (protein and name O N)"
|
96 |
+
)
|
97 |
+
metalcoordination = metalcoordination.reshape(metalcoordination.shape[0], 1)
|
98 |
+
|
99 |
+
hbondacceptor = prot.atomselect(
|
100 |
+
"(resname ASP GLU HIS HIE HIP HID SER THR MSE CYS MET and name ND2 NE2 OE1 OE2 OD1 OD2 OG OG1 SE SG) or name O"
|
101 |
+
)
|
102 |
+
hbondacceptor = hbondacceptor.reshape(metalcoordination.shape[0], 1)
|
103 |
+
|
104 |
+
hbonddonor = prot.atomselect(
|
105 |
+
"(resname ASN GLN ASH GLH TRP MSE SER THR MET CYS and name ND2 NE2 NE1 SG SE OG OG1) or name N"
|
106 |
+
)
|
107 |
+
hbonddonor = hbonddonor.reshape(metalcoordination.shape[0], 1)
|
108 |
+
|
109 |
+
positive = prot.atomselect(
|
110 |
+
"resname LYS ARG HIS HIE HIP HID and name NZ NH1 NH2 ND1 NE2 NE"
|
111 |
+
)
|
112 |
+
positive = positive.reshape(positive.shape[0], 1)
|
113 |
+
|
114 |
+
negative = prot.atomselect("(resname ASP GLU ASH GLH and name OD1 OD2 OE1 OE2)")
|
115 |
+
negative = negative.reshape(negative.shape[0], 1)
|
116 |
+
|
117 |
+
occupancy = prot.atomselect("protein and not hydrogen")
|
118 |
+
occupancy = occupancy.reshape(occupancy.shape[0], 1)
|
119 |
+
userchannels = np.hstack(
|
120 |
+
[
|
121 |
+
hydrophobic,
|
122 |
+
aromatic,
|
123 |
+
metalcoordination,
|
124 |
+
hbondacceptor,
|
125 |
+
hbonddonor,
|
126 |
+
positive,
|
127 |
+
negative,
|
128 |
+
occupancy,
|
129 |
+
]
|
130 |
+
)
|
131 |
+
prot_vox, prot_centers, prot_N = getVoxelDescriptors(
|
132 |
+
prot,
|
133 |
+
center=c,
|
134 |
+
userchannels=userchannels,
|
135 |
+
boxsize=size,
|
136 |
+
voxelsize=0.5,
|
137 |
+
validitychecks=False,
|
138 |
+
)
|
139 |
+
except:
|
140 |
+
raise VoxelizationError(f"voxelization of {id} failed")
|
141 |
+
nchannels = prot_vox.shape[1]
|
142 |
+
prot_vox_t = (
|
143 |
+
prot_vox.transpose()
|
144 |
+
.reshape([1, nchannels, prot_N[0], prot_N[1], prot_N[2]])
|
145 |
+
.copy()
|
146 |
+
)
|
147 |
+
|
148 |
+
voxels = torch.from_numpy(prot_vox_t)
|
149 |
+
return (voxels, prot_centers, prot_N, prot.copy())
|
150 |
+
|
151 |
+
|
152 |
+
def processStructures(pdb_file, resids, clean=True):
|
153 |
+
"""Process a pdb file and return a list of voxelized boxes centered on the residues
|
154 |
+
|
155 |
+
Parameters
|
156 |
+
----------
|
157 |
+
pdb_file : str
|
158 |
+
Path to pdb file
|
159 |
+
resids : list
|
160 |
+
List of resids to center the voxels on
|
161 |
+
clean : bool
|
162 |
+
If True, remove all non-protein residues from the pdb file
|
163 |
+
|
164 |
+
Returns
|
165 |
+
-------
|
166 |
+
voxels : torch.Tensor
|
167 |
+
Voxelized boxes with 8 channels (N, 8,32,32,32)
|
168 |
+
prot_centers_list : list
|
169 |
+
List of the centers of the voxels (N*32**32*32,3)
|
170 |
+
prot_n_list : list
|
171 |
+
List of the number of voxels in each box (N,3)
|
172 |
+
envs: list
|
173 |
+
List of tuples (prot, idx) (N)
|
174 |
+
"""
|
175 |
+
|
176 |
+
start_time_processing = time.time()
|
177 |
+
|
178 |
+
# load molecule using MoleculeKit
|
179 |
+
try:
|
180 |
+
prot = Molecule(pdb_file)
|
181 |
+
except:
|
182 |
+
raise IOError("could not read pdbfile")
|
183 |
+
|
184 |
+
if clean:
|
185 |
+
prot.filter("protein and not hydrogen")
|
186 |
+
|
187 |
+
environments = []
|
188 |
+
for idx in resids:
|
189 |
+
try:
|
190 |
+
environments.append((prot.copy(), idx))
|
191 |
+
except:
|
192 |
+
print("ignoring " + idx)
|
193 |
+
|
194 |
+
prot_centers_list = []
|
195 |
+
prot_n_list = []
|
196 |
+
envs = []
|
197 |
+
|
198 |
+
results = [voxelize_single_notcentered(x) for x in environments]
|
199 |
+
|
200 |
+
voxels = torch.empty(len(results), 8, 32, 32, 32, device="cuda")
|
201 |
+
|
202 |
+
vox_env, prot_centers_list, prot_n_list, envs = zip(*results)
|
203 |
+
|
204 |
+
for i, vox_env in enumerate(vox_env):
|
205 |
+
voxels[i] = vox_env
|
206 |
+
|
207 |
+
print(f"Voxelization took {time.time() - start_time_processing:.3f} seconds ")
|
208 |
+
|
209 |
+
return voxels, prot_centers_list, prot_n_list, envs
|
weights/Metal3D.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8f23dc5c28ffb03a77756f7e2613cb6f4b92425c1db87b422295ddaab1204515
|
3 |
+
size 7872827
|
weights/metal_0.5A_v3_d0.2_16Abox.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b5a1b0c5ea6c5dcdedfae4e24b8461da107ea78c8b96c7db8f44db532c87246f
|
3 |
+
size 13815931
|