Simon Duerr commited on
Commit
b4346be
·
1 Parent(s): 3897ec7
app.py CHANGED
@@ -1,19 +1,253 @@
1
  import gradio as gr
2
 
3
- def update(name):
4
- return f"Welcome to Gradio, {name}!"
 
 
5
 
6
- demo = gr.Blocks()
 
 
 
7
 
8
- with demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  gr.Markdown("# Metal3D")
 
 
 
 
 
 
10
  with gr.Group():
11
- inp = gr.Textbox(placeholder="2CBA", label="PDB or Uniprot code")
12
- file = gr.File(file_count=1, label="Upload a PDB file")
13
-
14
- btn = gr.Button("Run Metal3D")
15
- out = gr.Textbox()
 
 
 
 
 
 
 
 
16
  mol = gr.HTML()
17
- btn.click(fn=update, inputs=inp, outputs=out)
 
 
18
 
19
- demo.launch()
 
1
  import gradio as gr
2
 
3
+ import urllib
4
+ import re
5
+ import sys
6
+ import warnings
7
 
8
+ import torch
9
+ import torch.nn as nn
10
+ import ipywidgets as widgets
11
+ from ipywidgets import interact, fixed
12
 
13
+ from utils.helpers import *
14
+ from utils.voxelization import processStructures
15
+ from utils.model import Model
16
+ import numpy as np
17
+
18
+ import os
19
+
20
+ def update(inp, file, mode):
21
+ try:
22
+ pdb_file = file.name
23
+ except:
24
+ print("using pdbfile")
25
+
26
+ try:
27
+ pdb_file = inp
28
+ if (
29
+ re.match(
30
+ "[OPQ][0-9][A-Z0-9]{3}[0-9]|[A-NR-Z][0-9]([A-Z][A-Z0-9]{2}[0-9]){1,2}",
31
+ pdb_file,
32
+ ).group()
33
+ == pdb_file
34
+ ):
35
+ urllib.request.urlretrieve(
36
+ f"https://alphafold.ebi.ac.uk/files/AF-{pdb_file}-F1-model_v2.pdb",
37
+ f"files/{pdb_file}.pdb",
38
+ )
39
+ except AttributeError:
40
+ if len(inp) == 4:
41
+ pdb_file = inp
42
+ urllib.request.urlretrieve(
43
+ f"http://files.rcsb.org/download/{pdb_file.lower()}.pdb1",
44
+ f"files/{pdb_file}.pdb",
45
+ )
46
+ else:
47
+ return "pdb code must be 4 letters or Uniprot code does not match", ""
48
+
49
+ if mode == "All residues":
50
+ ids = get_all_protein_resids(
51
+ f"files/{pdb_file}.pdb",
52
+ )
53
+ else:
54
+ ids = get_all_metalbinding_resids(f"files/{pdb_file}.pdb")
55
+
56
+ voxels, prot_centers, prot_N, prots = processStructures(pdb_file, ids)
57
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
+ voxels.to(device)
59
+ print(voxels.shape)
60
+ model = Model()
61
+ model.to(device)
62
+ model.load_state_dict(torch.load("weights/metal_0.5A_v3_d0.2_16Abox.pth"))
63
+ model.eval()
64
+ with warnings.catch_warnings():
65
+ warnings.filterwarnings("ignore")
66
+ output = model(voxels)
67
+ print(output.shape)
68
+ prot_v = np.vstack(prot_centers)
69
+ output_v = output.flatten().cpu().detach().numpy()
70
+ bb = get_bb(prot_v)
71
+ gridres = 0.5
72
+ grid, box_N = create_grid_fromBB(bb, voxelSize=gridres)
73
+ probability_values = get_probability_mean(grid, prot_v, output_v)
74
+ print(probability_values.shape)
75
+ write_cubefile(
76
+ bb,
77
+ probability_values,
78
+ box_N,
79
+ outname=f"output/metal_{pdb_file}.cube",
80
+ gridres=gridres,
81
+ )
82
+ message = find_unique_sites(
83
+ probability_values,
84
+ grid,
85
+ writeprobes=True,
86
+ probefile=f"output/probes_{pdb_file}.pdb",
87
+ threshold=7,
88
+ p=0.15,
89
+ )
90
+
91
+ return message, molecule(
92
+ f"files/{pdb_file}.pdb",
93
+ f"output/probes_{pdb_file}.pdb",
94
+ f"output/metal_{pdb_file}.cube",
95
+ )
96
+
97
+
98
+ def test():
99
+ x = """<!DOCTYPE html>
100
+ <html>
101
+ <head>
102
+ <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
103
+ </head>
104
+ <body>
105
+ <script src="https://3Dmol.org/build/3Dmol-min.js" async></script> <div style="height: 400px; width: 400px; position: relative;" class="viewer_3Dmoljs" data-pdb="2POR" data-backgroundcolor="0xffffff" data-style="stick" ></div>
106
+ </body></html>"""
107
+ return f"""<iframe style="width: 100%; height: 480px" name="result" allow="midi; geolocation; microphone; camera;
108
+ display-capture; encrypted-media;" sandbox="allow-modals allow-forms
109
+ allow-scripts allow-same-origin allow-popups
110
+ allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
111
+ allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
112
+
113
+
114
+ def read_mol(molpath):
115
+ with open(molpath, "r") as fp:
116
+ lines = fp.readlines()
117
+ mol = ""
118
+ for l in lines:
119
+ mol += l
120
+ return mol
121
+
122
+
123
+ def molecule(pdb, probes, cube):
124
+ mol = read_mol(pdb)
125
+ probes = read_mol(probes)
126
+ cubefile = read_mol(cube)
127
+ x = (
128
+ """<!DOCTYPE html>
129
+ <html>
130
+ <head>
131
+ <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
132
+ <style>
133
+ body{
134
+ font-family:sans-serif
135
+ }
136
+ .mol-container {
137
+ width: 100%;
138
+ height: 400px;
139
+ position: relative;
140
+ }
141
+ .slider{
142
+ width:80%;
143
+ margin:0 auto
144
+ }
145
+ .slidercontainer{
146
+ display:flex;
147
+ }
148
+ .slidercontainer > * + * {
149
+ margin-left: 0.5rem;
150
+ }
151
+ #isovalue{
152
+ text-align:right}
153
+ </style>
154
+ <script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
155
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/rangeslider.js/2.3.3/rangeslider.min.js" integrity="sha512-BUlWdwDeJo24GIubM+z40xcj/pjw7RuULBkxOTc+0L9BaGwZPwiwtbiSVzv31qR7TWx7bs6OPTE5IyfLOorboQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
156
+ </head>
157
+ <body>
158
+ <div class="slidercontainer">
159
+ <span>Isovalue </span>
160
+ <span id="isovalue">0.5</span>
161
+ <input class="slider" type="range" id="rangeslider" min="0" max="1" step="0.025" value=0.5>
162
+ </div>
163
+
164
+ <div id="container" class="mol-container"></div>
165
+ <script>
166
+ let viewer = null;
167
+ let voldata = null;
168
+ $(document).ready(function () {
169
+ let element = $("#container");
170
+ let config = { backgroundColor: "white" };
171
+ viewer = $3Dmol.createViewer( element, config );
172
+ viewer.ui.initiateUI();
173
+ let data = `"""
174
+ + mol
175
+ + """`
176
+ viewer.addModel( data, "pdb" );
177
+
178
+ let cubefile = `"""
179
+ + cubefile
180
+ + """`
181
+ voldata = new $3Dmol.VolumeData(cubefile, "cube");
182
+ viewer.addIsosurface(voldata, { isoval: 0.7 , color: "blue", alpha: 0.85, smoothness: 1 });
183
+ viewer.getModel(0).setStyle({}, {cartoon: {color: "grayCarbon"}});
184
+ let probes =`"""
185
+ + probes
186
+ + """`
187
+ viewer.addModel(probes, "pdb");
188
+ viewer.getModel(1).setStyle({ "resn": "ZN" }, { "sphere": { }});
189
+ viewer.getModel(1).setHoverable({}, true,
190
+ function (atom, viewer, event, container) {
191
+ if (!atom.label) {
192
+ atom.label = viewer.addLabel("ZN p=" + atom.pdbline.substring(55, 60), { position: atom, backgroundColor: "mintcream", fontColor: "black" });
193
+ }
194
+ },
195
+ function (atom, viewer) {
196
+ if (atom.label) {
197
+ viewer.removeLabel(atom.label);
198
+ delete atom.label;
199
+ }
200
+ }
201
+ );
202
+ viewer.zoomTo();
203
+ viewer.render();
204
+ viewer.zoom(0.8, 2000);
205
+ });
206
+ </script>
207
+ <script>
208
+ $("#rangeslider").rangeslider().on("change", function (el) {
209
+ isoval = parseFloat(el.target.value);
210
+ $("#isovalue").text(el.target.value)
211
+ viewer.addIsosurface(voldata, { isoval: parseFloat(el.target.value), color: "blue", alpha: 0.85, smoothness: 1 });
212
+ viewer.render();
213
+ });
214
+ </script>
215
+ </body></html>"""
216
+ )
217
+
218
+ return f"""<iframe style="width: 100%; height: 480px" name="result" allow="midi; geolocation; microphone; camera;
219
+ display-capture; encrypted-media;" sandbox="allow-modals allow-forms
220
+ allow-scripts allow-same-origin allow-popups
221
+ allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
222
+ allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
223
+
224
+
225
+ metal3d = gr.Blocks()
226
+
227
+ with metal3d:
228
  gr.Markdown("# Metal3D")
229
+ gr.Markdown(
230
+ """
231
+ Details about implementation and code available here:
232
+ >Duerr, Levy and Roethlisberger, Predicting zinc ion location using deep learning, BioRxiv, 2022 "
233
+ """
234
+ )
235
  with gr.Group():
236
+ inp = gr.Textbox(
237
+ placeholder="PDB Code or Uniprot identifier", label="Input molecule"
238
+ )
239
+ gr.Markdown("or upload a file")
240
+ file = gr.File(file_count="single", type="file")
241
+ mode = gr.Radio(
242
+ ["All metalbinding residues (ASP, CYS, GLU, HIS)", "All residues"],
243
+ label="Residues to use for prediction",
244
+ )
245
+ btn = gr.Button("Run")
246
+
247
+ gr.Markdown("# Output")
248
+ out = gr.Textbox(label="status")
249
  mol = gr.HTML()
250
+ btn.click(fn=update, inputs=[inp, file, mode], outputs=[out, mol])
251
+
252
+ metal3d.launch()
253
 
 
utils/__pycache__/helpers.cpython-38.pyc ADDED
Binary file (7.65 kB). View file
 
utils/__pycache__/model.cpython-38.pyc ADDED
Binary file (1.41 kB). View file
 
utils/__pycache__/voxelization.cpython-38.pyc ADDED
Binary file (4.99 kB). View file
 
utils/helpers.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import multiprocessing
3
+ from multiprocessing import Pool
4
+ from turtle import width
5
+
6
+ import numpy as np
7
+
8
+ from moleculekit.molecule import Molecule
9
+ from scipy.spatial import KDTree
10
+ from sklearn.cluster import AgglomerativeClustering
11
+
12
+
13
+ def create_grid_fromBB(boundingBox, voxelSize=1):
14
+ """Create a grid from a bounding box.
15
+
16
+ Parameters
17
+ ----------
18
+ boundingBox : list
19
+ List of the form [xmin, xmax, ymin, ymax, zmin, zmax]
20
+ voxelSize : float
21
+ Size of the voxels in Angstrom
22
+
23
+ Returns
24
+ -------
25
+ grid : numpy.ndarray
26
+ Grid of shape (nx, ny, nz)
27
+ box_N : numpy.ndarray
28
+ Number of voxels in each dimension
29
+
30
+ """
31
+ # increase grid by 0.5 to sample everything
32
+ xrange = np.arange(boundingBox[0][0], boundingBox[1][0] + 0.5, step=voxelSize)
33
+ yrange = np.arange(boundingBox[0][1], boundingBox[1][1] + 0.5, step=voxelSize)
34
+ zrange = np.arange(boundingBox[0][2], boundingBox[1][2] + 0.5, step=voxelSize)
35
+
36
+ gridpoints = np.zeros((xrange.shape[0] * yrange.shape[0] * zrange.shape[0], 3))
37
+ i = 0
38
+ for x in xrange:
39
+ for y in yrange:
40
+ for z in zrange:
41
+ gridpoints[i][0] = x
42
+ gridpoints[i][1] = y
43
+ gridpoints[i][2] = z
44
+ i += 1
45
+ return gridpoints, (xrange.shape[0], yrange.shape[0], zrange.shape[0])
46
+
47
+
48
+ def get_bb(points):
49
+ """Return bounding box from a set of points (N,3)
50
+
51
+ Parameters
52
+ ----------
53
+ points : numpy.ndarray
54
+ Set of points (N,3)
55
+
56
+ Returns
57
+ -------
58
+ boundingBox : list
59
+ List of the form [xmin, xmax, ymin, ymax, zmin, zmax]
60
+
61
+ """
62
+ minx = np.min(points[:, 0])
63
+ maxx = np.max(points[:, 0])
64
+
65
+ miny = np.min(points[:, 1])
66
+ maxy = np.max(points[:, 1])
67
+
68
+ minz = np.min(points[:, 2])
69
+ maxz = np.max(points[:, 2])
70
+ bb = [[minx, miny, minz], [maxx, maxy, maxz]]
71
+ return bb
72
+
73
+
74
+ def get_all_protein_resids(pdb_file):
75
+ """Return all protein residues from a pdb file
76
+
77
+ Parameters
78
+ ----------
79
+ pdb_file : str
80
+ Path to pdb file
81
+
82
+ Returns
83
+ -------
84
+ resids : numpy.ndarray
85
+ Array of protein resids old -> new
86
+
87
+ """
88
+ try:
89
+ prot = Molecule(pdb_file)
90
+ except:
91
+ exit("could not read file")
92
+ prot.filter("protein")
93
+ return prot.get("index", sel="name CA")
94
+
95
+
96
+ def get_all_metalbinding_resids(pdb_file):
97
+ """Return all metal binding residues from a pdb file
98
+
99
+ Parameters
100
+ ----------
101
+ pdb_file : str
102
+ Path to pdb file
103
+
104
+ Returns
105
+ -------
106
+ resids : numpy.ndarray
107
+ id of resids that are metal binding
108
+
109
+ """
110
+
111
+ try:
112
+ prot = Molecule(pdb_file)
113
+ except:
114
+ exit("could not read file")
115
+ prot.filter("protein")
116
+ return prot.get(
117
+ "index",
118
+ sel="name CA and resname HIS HID HIE HIP CYS CYX GLU GLH GLN ASP ASH ASN GLN MET",
119
+ )
120
+
121
+
122
+ def compute_average_p_fast(point, cutoff=1):
123
+ """Using KDTree find the closest gridpoints
124
+
125
+ Parameters
126
+ ----------
127
+ point : numpy.ndarray
128
+ Point of shape (3,)
129
+ cutoff : float
130
+ Cutoff distance in Angstrom
131
+
132
+ Returns
133
+ -------
134
+ average_p : numpy.ndarray
135
+ Average probability of shape (1,)"""
136
+ p = 0
137
+ nearest_neighbors, indices = tree.query(
138
+ point, k=15, distance_upper_bound=cutoff, workers=1
139
+ )
140
+ if np.min(nearest_neighbors) != np.inf:
141
+ p = np.mean(output_v[indices[nearest_neighbors != np.inf]])
142
+ return p
143
+
144
+
145
+ def get_probability_mean(grid, prot_centers, pvalues):
146
+ """Compute the mean probability of all gridpoints from the globalgrid based on the individual boxes
147
+
148
+ Parameters
149
+ ----------
150
+ grid : numpy.ndarray
151
+ Grid of shape (nx, ny, nz)
152
+ prot_centers : numpy.ndarray
153
+ Protein centers of shape (N,3)
154
+ pvalues : numpy.ndarray
155
+ Probability values of shape (N,1)
156
+
157
+ Returns
158
+ -------
159
+ mean_p : numpy.ndarray
160
+ Mean probability over grid of shape (nx, ny, nz)
161
+ """
162
+ global output_v
163
+ output_v = pvalues
164
+ global prot_v
165
+ prot_v = prot_centers
166
+ cpuCount = multiprocessing.cpu_count()
167
+
168
+ global tree
169
+ tree = KDTree(prot_v)
170
+ p = Pool(cpuCount)
171
+ results = p.map(compute_average_p_fast, grid)
172
+ return np.array(results)
173
+
174
+
175
+ def write_cubefile(bb, pvalues, box_N, outname="Metal3D_pmap.cube", gridres=1):
176
+ """Write a cube file from a probability map
177
+ The cube specification from gaussian is used, distance are converted to bohr
178
+
179
+ Parameters
180
+ ----------
181
+ bb : list
182
+ List of the form [xmin, xmax, ymin, ymax, zmin, zmax]
183
+ pvalues : numpy.ndarray
184
+ Probability values of shape (nx, ny, nz)
185
+ box_N : tuple
186
+ Number of voxels in each dimension
187
+ outname : str
188
+ Name of the output file
189
+ gridres:float
190
+ Resolution of the grid used for writing the voxels
191
+
192
+ """
193
+
194
+ with open(outname, "w") as cube:
195
+ cube.write(" Metal3D Cube File\n")
196
+ cube.write(" Outer Loop: X, Middle Loop y, inner Loop z\n")
197
+
198
+ angstromToBohr = 1.89
199
+ cube.write(
200
+ f" 1 {bb[0][0]*angstromToBohr: .6f} {bb[0][1]*angstromToBohr: .6f} {bb[0][2]*angstromToBohr: .6f}\n"
201
+ )
202
+ cube.write(
203
+ f"{str(box_N[0]).rjust(5)} {1.890000*gridres:.9f} 0.000000 0.000000\n"
204
+ )
205
+ cube.write(
206
+ f"{str(box_N[1]).rjust(5)} 0.000000 {1.890000*gridres:.9f} 0.000000\n"
207
+ )
208
+ cube.write(
209
+ f"{str(box_N[2]).rjust(5)} 0.000000 0.000000 {1.890000*gridres:.9f}\n"
210
+ )
211
+ cube.write(" 1 1.000000 0.000000 0.000000 0.000000\n")
212
+
213
+ o = pvalues.reshape(box_N)
214
+ for x in range(box_N[0]):
215
+ for y in range(box_N[1]):
216
+ for z in range(box_N[2]):
217
+ cube.write(f" {o[x][y][z]: .5E}")
218
+ if z % 6 == 5:
219
+ cube.write("\n")
220
+ cube.write("\n")
221
+
222
+
223
+ def find_unique_sites(
224
+ pvalues, grid, writeprobes=False, probefile="probes.pdb", threshold=5, p=0.75
225
+ ):
226
+ """The probability voxels are points and the voxel clouds may contain multiple metals
227
+ This function finds the unique sites and returns the coordinates of the unique sites with the highest p for each cluster.
228
+ It uses the AgglomerativeClustering algorithm to find the unique sites.
229
+ The threshold is the maximum distance between two points in the same cluster it can be changed to get more metal points.
230
+
231
+ Parameters
232
+ ----------
233
+ pvalues : numpy.ndarray
234
+ Probability values of shape (N, 1)
235
+ grid : numpy.ndarray
236
+ Grid of shape (N, 3)
237
+ writeprobes : bool
238
+ If True, write the probes to a pdb file
239
+ probefile : str
240
+ Name of the output file
241
+ threshold : float
242
+ Maximum distance between two points in the same cluster
243
+ p : float
244
+ Minimum probability of a point to be considered a unique site
245
+
246
+ """
247
+
248
+ points = grid[pvalues > p]
249
+ point_p = pvalues[pvalues > p]
250
+ if len(points) == 0:
251
+ return "no metals found"
252
+ clustering = AgglomerativeClustering(
253
+ n_clusters=None, linkage="complete", distance_threshold=threshold
254
+ ).fit(points)
255
+
256
+ message = f"min metal p={p}, n(metals) found: {clustering.n_clusters_}"
257
+
258
+ sites = []
259
+ for i in range(clustering.n_clusters_):
260
+ c_points = points[clustering.labels_ == i]
261
+ c_points_p = point_p[clustering.labels_ == i]
262
+
263
+ position = c_points[np.argmax(c_points_p)]
264
+ sites.append((position, np.max(c_points_p)))
265
+ if writeprobes:
266
+ print(f"writing probes to {probefile}")
267
+ with open(probefile, "w") as f:
268
+ for i, site in enumerate(sites):
269
+ f.write(
270
+ f"HETATM {i+1:3} ZN ZN A {i+1:3} {site[0][0]: 8.3f}{site[0][1]: 8.3f}{site[0][2]: 8.3f} {site[1]:.2f} 0.0 ZN2+\n"
271
+ )
272
+ return message
utils/model.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class Model(nn.Module):
9
+ """Model with same padding
10
+ Conv5 uses a large filter size to aggregate the features from the whole box"""
11
+
12
+ def __init__(self):
13
+ super(Model, self).__init__()
14
+ self.conv1 = nn.Conv3d(8, 32, 3, padding="same")
15
+ self.conv2 = nn.Conv3d(32, 64, 3, padding="same")
16
+ self.conv3 = nn.Conv3d(64, 80, 3, padding="same")
17
+ self.conv4 = nn.Conv3d(80, 20, 3, padding="same")
18
+ self.conv5 = nn.Conv3d(20, 20, 20, padding="same")
19
+ self.conv6 = nn.Conv3d(20, 16, 3, padding="same")
20
+ self.conv7 = nn.Conv3d(16, 1, 3, padding="same")
21
+ self.dropout1 = nn.Dropout(0.2)
22
+
23
+ def forward(self, x):
24
+ x = self.conv1(x)
25
+ x = F.relu(x)
26
+ x = self.conv2(x)
27
+ x = F.relu(x)
28
+ x = self.conv3(x)
29
+ x = F.relu(x)
30
+
31
+ x = self.conv4(x)
32
+ x = F.relu(x)
33
+
34
+ x = self.conv5(x)
35
+ x = F.relu(x)
36
+ x = self.dropout1(x)
37
+ x = self.conv6(x)
38
+ x = F.relu(x)
39
+
40
+ x = self.conv7(x)
41
+ x = torch.sigmoid(x)
42
+ return x
utils/voxelization.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import multiprocessing
3
+ from multiprocessing import Pool
4
+
5
+ import torch
6
+ import numpy as np
7
+
8
+ from moleculekit.molecule import Molecule
9
+ from moleculekit.tools.voxeldescriptors import getVoxelDescriptors
10
+ from moleculekit.tools.atomtyper import prepareProteinForAtomtyping
11
+ from moleculekit.tools.preparation import systemPrepare
12
+
13
+
14
+ class AtomtypingError(Exception):
15
+ pass
16
+
17
+
18
+ class StructureCleaningError(Exception):
19
+ pass
20
+
21
+
22
+ class ProteinPrepareError(Exception):
23
+ pass
24
+
25
+
26
+ class VoxelizationError(Exception):
27
+ pass
28
+
29
+
30
+ metal_atypes = (
31
+ "MG",
32
+ "ZN",
33
+ "MN",
34
+ "CA",
35
+ "FE",
36
+ "HG",
37
+ "CD",
38
+ "NI",
39
+ "CO",
40
+ "CU",
41
+ "K",
42
+ "LI",
43
+ "Mg",
44
+ "Zn",
45
+ "Mn",
46
+ "Ca",
47
+ "Fe",
48
+ "Hg",
49
+ "Cd",
50
+ "Ni",
51
+ "Co",
52
+ "Cu",
53
+ "Li",
54
+ )
55
+
56
+
57
+ def voxelize_single_notcentered(env):
58
+ """voxelize 1 structure, executed on a single CPU
59
+ Using 7 of the 8 channels supplied by moleculekit(excluding metals)
60
+ Additionally it uses all the metalbinding residues as channel
61
+
62
+ Parameters
63
+ ----------
64
+ env : tuple
65
+ Tuple of the form (prot, idx)
66
+
67
+ Returns
68
+ -------
69
+ voxels : torch.tensor
70
+ Voxelized structure with 8 channels (8,20,20,20)
71
+ prot_centers : list
72
+ List of the centers of the voxels (20x20x20,3)
73
+ prot_n : list
74
+ List of the number of voxels in each voxel (20x20x20)
75
+ prot : moleculekit.Molecule
76
+ Moleculekit molecule
77
+ """
78
+ prot, id = env
79
+
80
+ c = prot.get("coords", sel=f"index {id} and name CA")
81
+
82
+ size = [16, 16, 16] # size of box
83
+ voxels = torch.zeros(8, 32, 32, 32)
84
+
85
+ try:
86
+ hydrophobic = prot.atomselect("element C")
87
+ hydrophobic = hydrophobic.reshape(hydrophobic.shape[0], 1)
88
+
89
+ aromatic = prot.atomselect(
90
+ "resname HIS HIE HIP HID TRP TYR PHE and sidechain and not name CB and not hydrogen"
91
+ )
92
+ aromatic = aromatic.reshape(aromatic.shape[0], 1)
93
+
94
+ metalcoordination = prot.atomselect(
95
+ "(name ND1 NE2 SG OE1 OE2 OD2) or (protein and name O N)"
96
+ )
97
+ metalcoordination = metalcoordination.reshape(metalcoordination.shape[0], 1)
98
+
99
+ hbondacceptor = prot.atomselect(
100
+ "(resname ASP GLU HIS HIE HIP HID SER THR MSE CYS MET and name ND2 NE2 OE1 OE2 OD1 OD2 OG OG1 SE SG) or name O"
101
+ )
102
+ hbondacceptor = hbondacceptor.reshape(metalcoordination.shape[0], 1)
103
+
104
+ hbonddonor = prot.atomselect(
105
+ "(resname ASN GLN ASH GLH TRP MSE SER THR MET CYS and name ND2 NE2 NE1 SG SE OG OG1) or name N"
106
+ )
107
+ hbonddonor = hbonddonor.reshape(metalcoordination.shape[0], 1)
108
+
109
+ positive = prot.atomselect(
110
+ "resname LYS ARG HIS HIE HIP HID and name NZ NH1 NH2 ND1 NE2 NE"
111
+ )
112
+ positive = positive.reshape(positive.shape[0], 1)
113
+
114
+ negative = prot.atomselect("(resname ASP GLU ASH GLH and name OD1 OD2 OE1 OE2)")
115
+ negative = negative.reshape(negative.shape[0], 1)
116
+
117
+ occupancy = prot.atomselect("protein and not hydrogen")
118
+ occupancy = occupancy.reshape(occupancy.shape[0], 1)
119
+ userchannels = np.hstack(
120
+ [
121
+ hydrophobic,
122
+ aromatic,
123
+ metalcoordination,
124
+ hbondacceptor,
125
+ hbonddonor,
126
+ positive,
127
+ negative,
128
+ occupancy,
129
+ ]
130
+ )
131
+ prot_vox, prot_centers, prot_N = getVoxelDescriptors(
132
+ prot,
133
+ center=c,
134
+ userchannels=userchannels,
135
+ boxsize=size,
136
+ voxelsize=0.5,
137
+ validitychecks=False,
138
+ )
139
+ except:
140
+ raise VoxelizationError(f"voxelization of {id} failed")
141
+ nchannels = prot_vox.shape[1]
142
+ prot_vox_t = (
143
+ prot_vox.transpose()
144
+ .reshape([1, nchannels, prot_N[0], prot_N[1], prot_N[2]])
145
+ .copy()
146
+ )
147
+
148
+ voxels = torch.from_numpy(prot_vox_t)
149
+ return (voxels, prot_centers, prot_N, prot.copy())
150
+
151
+
152
+ def processStructures(pdb_file, resids, clean=True):
153
+ """Process a pdb file and return a list of voxelized boxes centered on the residues
154
+
155
+ Parameters
156
+ ----------
157
+ pdb_file : str
158
+ Path to pdb file
159
+ resids : list
160
+ List of resids to center the voxels on
161
+ clean : bool
162
+ If True, remove all non-protein residues from the pdb file
163
+
164
+ Returns
165
+ -------
166
+ voxels : torch.Tensor
167
+ Voxelized boxes with 8 channels (N, 8,32,32,32)
168
+ prot_centers_list : list
169
+ List of the centers of the voxels (N*32**32*32,3)
170
+ prot_n_list : list
171
+ List of the number of voxels in each box (N,3)
172
+ envs: list
173
+ List of tuples (prot, idx) (N)
174
+ """
175
+
176
+ start_time_processing = time.time()
177
+
178
+ # load molecule using MoleculeKit
179
+ try:
180
+ prot = Molecule(pdb_file)
181
+ except:
182
+ raise IOError("could not read pdbfile")
183
+
184
+ if clean:
185
+ prot.filter("protein and not hydrogen")
186
+
187
+ environments = []
188
+ for idx in resids:
189
+ try:
190
+ environments.append((prot.copy(), idx))
191
+ except:
192
+ print("ignoring " + idx)
193
+
194
+ prot_centers_list = []
195
+ prot_n_list = []
196
+ envs = []
197
+
198
+ results = [voxelize_single_notcentered(x) for x in environments]
199
+
200
+ voxels = torch.empty(len(results), 8, 32, 32, 32, device="cuda")
201
+
202
+ vox_env, prot_centers_list, prot_n_list, envs = zip(*results)
203
+
204
+ for i, vox_env in enumerate(vox_env):
205
+ voxels[i] = vox_env
206
+
207
+ print(f"Voxelization took {time.time() - start_time_processing:.3f} seconds ")
208
+
209
+ return voxels, prot_centers_list, prot_n_list, envs
weights/Metal3D.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f23dc5c28ffb03a77756f7e2613cb6f4b92425c1db87b422295ddaab1204515
3
+ size 7872827
weights/metal_0.5A_v3_d0.2_16Abox.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5a1b0c5ea6c5dcdedfae4e24b8461da107ea78c8b96c7db8f44db532c87246f
3
+ size 13815931