woshixuhao commited on
Commit
7daeb71
1 Parent(s): c1888c9

Upload 3 files

Browse files
Files changed (3) hide show
  1. GeoGNN_model.pth +3 -0
  2. app.py +1610 -0
  3. column_descriptor.npy +3 -0
GeoGNN_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3ff5277fd1a9269b3166882941438f3f68d43b2a8eefedf1d4b61b2801a66bf8
3
+ size 3139843
app.py ADDED
@@ -0,0 +1,1610 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch_geometric.nn import MessagePassing
3
+ from compound_tools import *
4
+ from rdkit.Chem import Descriptors
5
+ from torch_geometric.data import Data
6
+ import argparse
7
+ import warnings
8
+ from rdkit.Chem.Descriptors import rdMolDescriptors
9
+ import pandas as pd
10
+ import os
11
+ from mordred import Calculator, descriptors, is_missing
12
+ from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import numpy as np
16
+ from rdkit import Chem
17
+ from rdkit.Chem import AllChem
18
+ from rdkit.Chem import rdchem
19
+ import gradio as gr
20
+ DAY_LIGHT_FG_SMARTS_LIST = [
21
+ # C
22
+ "[CX4]",
23
+ "[$([CX2](=C)=C)]",
24
+ "[$([CX3]=[CX3])]",
25
+ "[$([CX2]#C)]",
26
+ # C & O
27
+ "[CX3]=[OX1]",
28
+ "[$([CX3]=[OX1]),$([CX3+]-[OX1-])]",
29
+ "[CX3](=[OX1])C",
30
+ "[OX1]=CN",
31
+ "[CX3](=[OX1])O",
32
+ "[CX3](=[OX1])[F,Cl,Br,I]",
33
+ "[CX3H1](=O)[#6]",
34
+ "[CX3](=[OX1])[OX2][CX3](=[OX1])",
35
+ "[NX3][CX3](=[OX1])[#6]",
36
+ "[NX3][CX3]=[NX3+]",
37
+ "[NX3,NX4+][CX3](=[OX1])[OX2,OX1-]",
38
+ "[NX3][CX3](=[OX1])[OX2H0]",
39
+ "[NX3,NX4+][CX3](=[OX1])[OX2H,OX1-]",
40
+ "[CX3](=O)[O-]",
41
+ "[CX3](=[OX1])(O)O",
42
+ "[CX3](=[OX1])([OX2])[OX2H,OX1H0-1]",
43
+ "C[OX2][CX3](=[OX1])[OX2]C",
44
+ "[CX3](=O)[OX2H1]",
45
+ "[CX3](=O)[OX1H0-,OX2H1]",
46
+ "[NX3][CX2]#[NX1]",
47
+ "[#6][CX3](=O)[OX2H0][#6]",
48
+ "[#6][CX3](=O)[#6]",
49
+ "[OD2]([#6])[#6]",
50
+ # H
51
+ "[H]",
52
+ "[!#1]",
53
+ "[H+]",
54
+ "[+H]",
55
+ "[!H]",
56
+ # N
57
+ "[NX3;H2,H1;!$(NC=O)]",
58
+ "[NX3][CX3]=[CX3]",
59
+ "[NX3;H2;!$(NC=[!#6]);!$(NC#[!#6])][#6]",
60
+ "[NX3;H2,H1;!$(NC=O)].[NX3;H2,H1;!$(NC=O)]",
61
+ "[NX3][$(C=C),$(cc)]",
62
+ "[NX3,NX4+][CX4H]([*])[CX3](=[OX1])[O,N]",
63
+ "[NX3H2,NH3X4+][CX4H]([*])[CX3](=[OX1])[NX3,NX4+][CX4H]([*])[CX3](=[OX1])[OX2H,OX1-]",
64
+ "[$([NX3H2,NX4H3+]),$([NX3H](C)(C))][CX4H]([*])[CX3](=[OX1])[OX2H,OX1-,N]",
65
+ "[CH3X4]",
66
+ "[CH2X4][CH2X4][CH2X4][NHX3][CH0X3](=[NH2X3+,NHX2+0])[NH2X3]",
67
+ "[CH2X4][CX3](=[OX1])[NX3H2]",
68
+ "[CH2X4][CX3](=[OX1])[OH0-,OH]",
69
+ "[CH2X4][SX2H,SX1H0-]",
70
+ "[CH2X4][CH2X4][CX3](=[OX1])[OH0-,OH]",
71
+ "[$([$([NX3H2,NX4H3+]),$([NX3H](C)(C))][CX4H2][CX3](=[OX1])[OX2H,OX1-,N])]",
72
+ "[CH2X4][#6X3]1:[$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]:\
73
+ [$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]1",
74
+ "[CHX4]([CH3X4])[CH2X4][CH3X4]",
75
+ "[CH2X4][CHX4]([CH3X4])[CH3X4]",
76
+ "[CH2X4][CH2X4][CH2X4][CH2X4][NX4+,NX3+0]",
77
+ "[CH2X4][CH2X4][SX2][CH3X4]",
78
+ "[CH2X4][cX3]1[cX3H][cX3H][cX3H][cX3H][cX3H]1",
79
+ "[$([NX3H,NX4H2+]),$([NX3](C)(C)(C))]1[CX4H]([CH2][CH2][CH2]1)[CX3](=[OX1])[OX2H,OX1-,N]",
80
+ "[CH2X4][OX2H]",
81
+ "[NX3][CX3]=[SX1]",
82
+ "[CHX4]([CH3X4])[OX2H]",
83
+ "[CH2X4][cX3]1[cX3H][nX3H][cX3]2[cX3H][cX3H][cX3H][cX3H][cX3]12",
84
+ "[CH2X4][cX3]1[cX3H][cX3H][cX3]([OHX2,OH0X1-])[cX3H][cX3H]1",
85
+ "[CHX4]([CH3X4])[CH3X4]",
86
+ "N[CX4H2][CX3](=[OX1])[O,N]",
87
+ "N1[CX4H]([CH2][CH2][CH2]1)[CX3](=[OX1])[O,N]",
88
+ "[$(*-[NX2-]-[NX2+]#[NX1]),$(*-[NX2]=[NX2+]=[NX1-])]",
89
+ "[$([NX1-]=[NX2+]=[NX1-]),$([NX1]#[NX2+]-[NX1-2])]",
90
+ "[#7]",
91
+ "[NX2]=N",
92
+ "[NX2]=[NX2]",
93
+ "[$([NX2]=[NX3+]([O-])[#6]),$([NX2]=[NX3+0](=[O])[#6])]",
94
+ "[$([#6]=[N+]=[N-]),$([#6-]-[N+]#[N])]",
95
+ "[$([nr5]:[nr5,or5,sr5]),$([nr5]:[cr5]:[nr5,or5,sr5])]",
96
+ "[NX3][NX3]",
97
+ "[NX3][NX2]=[*]",
98
+ "[CX3;$([C]([#6])[#6]),$([CH][#6])]=[NX2][#6]",
99
+ "[$([CX3]([#6])[#6]),$([CX3H][#6])]=[$([NX2][#6]),$([NX2H])]",
100
+ "[NX3+]=[CX3]",
101
+ "[CX3](=[OX1])[NX3H][CX3](=[OX1])",
102
+ "[CX3](=[OX1])[NX3H0]([#6])[CX3](=[OX1])",
103
+ "[CX3](=[OX1])[NX3H0]([NX3H0]([CX3](=[OX1]))[CX3](=[OX1]))[CX3](=[OX1])",
104
+ "[$([NX3](=[OX1])(=[OX1])O),$([NX3+]([OX1-])(=[OX1])O)]",
105
+ "[$([OX1]=[NX3](=[OX1])[OX1-]),$([OX1]=[NX3+]([OX1-])[OX1-])]",
106
+ "[NX1]#[CX2]",
107
+ "[CX1-]#[NX2+]",
108
+ "[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]",
109
+ "[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8].[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]",
110
+ "[NX2]=[OX1]",
111
+ "[$([#7+][OX1-]),$([#7v5]=[OX1]);!$([#7](~[O])~[O]);!$([#7]=[#7])]",
112
+ # O
113
+ "[OX2H]",
114
+ "[#6][OX2H]",
115
+ "[OX2H][CX3]=[OX1]",
116
+ "[OX2H]P",
117
+ "[OX2H][#6X3]=[#6]",
118
+ "[OX2H][cX3]:[c]",
119
+ "[OX2H][$(C=C),$(cc)]",
120
+ "[$([OH]-*=[!#6])]",
121
+ "[OX2,OX1-][OX2,OX1-]",
122
+ # P
123
+ "[$(P(=[OX1])([$([OX2H]),$([OX1-]),$([OX2]P)])([$([OX2H]),$([OX1-]),\
124
+ $([OX2]P)])[$([OX2H]),$([OX1-]),$([OX2]P)]),$([P+]([OX1-])([$([OX2H]),$([OX1-])\
125
+ ,$([OX2]P)])([$([OX2H]),$([OX1-]),$([OX2]P)])[$([OX2H]),$([OX1-]),$([OX2]P)])]",
126
+ "[$(P(=[OX1])([OX2][#6])([$([OX2H]),$([OX1-]),$([OX2][#6])])[$([OX2H]),\
127
+ $([OX1-]),$([OX2][#6]),$([OX2]P)]),$([P+]([OX1-])([OX2][#6])([$([OX2H]),$([OX1-]),\
128
+ $([OX2][#6])])[$([OX2H]),$([OX1-]),$([OX2][#6]),$([OX2]P)])]",
129
+ # S
130
+ "[S-][CX3](=S)[#6]",
131
+ "[#6X3](=[SX1])([!N])[!N]",
132
+ "[SX2]",
133
+ "[#16X2H]",
134
+ "[#16!H0]",
135
+ "[#16X2H0]",
136
+ "[#16X2H0][!#16]",
137
+ "[#16X2H0][#16X2H0]",
138
+ "[#16X2H0][!#16].[#16X2H0][!#16]",
139
+ "[$([#16X3](=[OX1])[OX2H0]),$([#16X3+]([OX1-])[OX2H0])]",
140
+ "[$([#16X3](=[OX1])[OX2H,OX1H0-]),$([#16X3+]([OX1-])[OX2H,OX1H0-])]",
141
+ "[$([#16X4](=[OX1])=[OX1]),$([#16X4+2]([OX1-])[OX1-])]",
142
+ "[$([#16X4](=[OX1])(=[OX1])([#6])[#6]),$([#16X4+2]([OX1-])([OX1-])([#6])[#6])]",
143
+ "[$([#16X4](=[OX1])(=[OX1])([#6])[OX2H,OX1H0-]),$([#16X4+2]([OX1-])([OX1-])([#6])[OX2H,OX1H0-])]",
144
+ "[$([#16X4](=[OX1])(=[OX1])([#6])[OX2H0]),$([#16X4+2]([OX1-])([OX1-])([#6])[OX2H0])]",
145
+ "[$([#16X4]([NX3])(=[OX1])(=[OX1])[#6]),$([#16X4+2]([NX3])([OX1-])([OX1-])[#6])]",
146
+ "[SX4](C)(C)(=O)=N",
147
+ "[$([SX4](=[OX1])(=[OX1])([!O])[NX3]),$([SX4+2]([OX1-])([OX1-])([!O])[NX3])]",
148
+ "[$([#16X3]=[OX1]),$([#16X3+][OX1-])]",
149
+ "[$([#16X3](=[OX1])([#6])[#6]),$([#16X3+]([OX1-])([#6])[#6])]",
150
+ "[$([#16X4](=[OX1])(=[OX1])([OX2H,OX1H0-])[OX2][#6]),$([#16X4+2]([OX1-])([OX1-])([OX2H,OX1H0-])[OX2][#6])]",
151
+ "[$([SX4](=O)(=O)(O)O),$([SX4+2]([O-])([O-])(O)O)]",
152
+ "[$([#16X4](=[OX1])(=[OX1])([OX2][#6])[OX2][#6]),$([#16X4](=[OX1])(=[OX1])([OX2][#6])[OX2][#6])]",
153
+ "[$([#16X4]([NX3])(=[OX1])(=[OX1])[OX2][#6]),$([#16X4+2]([NX3])([OX1-])([OX1-])[OX2][#6])]",
154
+ "[$([#16X4]([NX3])(=[OX1])(=[OX1])[OX2H,OX1H0-]),$([#16X4+2]([NX3])([OX1-])([OX1-])[OX2H,OX1H0-])]",
155
+ "[#16X2][OX2H,OX1H0-]",
156
+ "[#16X2][OX2H0]",
157
+ # X
158
+ "[#6][F,Cl,Br,I]",
159
+ "[F,Cl,Br,I]",
160
+ "[F,Cl,Br,I].[F,Cl,Br,I].[F,Cl,Br,I]",
161
+ ]
162
+
163
+
164
+ def get_gasteiger_partial_charges(mol, n_iter=12):
165
+ """
166
+ Calculates list of gasteiger partial charges for each atom in mol object.
167
+ Args:
168
+ mol: rdkit mol object.
169
+ n_iter(int): number of iterations. Default 12.
170
+ Returns:
171
+ list of computed partial charges for each atom.
172
+ """
173
+ Chem.rdPartialCharges.ComputeGasteigerCharges(mol, nIter=n_iter,
174
+ throwOnParamFailure=True)
175
+ partial_charges = [float(a.GetProp('_GasteigerCharge')) for a in
176
+ mol.GetAtoms()]
177
+ return partial_charges
178
+
179
+
180
+ def create_standardized_mol_id(smiles):
181
+ """
182
+ Args:
183
+ smiles: smiles sequence.
184
+ Returns:
185
+ inchi.
186
+ """
187
+ if check_smiles_validity(smiles):
188
+ # remove stereochemistry
189
+ smiles = AllChem.MolToSmiles(AllChem.MolFromSmiles(smiles),
190
+ isomericSmiles=False)
191
+ mol = AllChem.MolFromSmiles(smiles)
192
+ if not mol is None: # to catch weird issue with O=C1O[al]2oc(=O)c3ccc(cn3)c3ccccc3c3cccc(c3)c3ccccc3c3cc(C(F)(F)F)c(cc3o2)-c2ccccc2-c2cccc(c2)-c2ccccc2-c2cccnc21
193
+ if '.' in smiles: # if multiple species, pick largest molecule
194
+ mol_species_list = split_rdkit_mol_obj(mol)
195
+ largest_mol = get_largest_mol(mol_species_list)
196
+ inchi = AllChem.MolToInchi(largest_mol)
197
+ else:
198
+ inchi = AllChem.MolToInchi(mol)
199
+ return inchi
200
+ else:
201
+ return
202
+ else:
203
+ return
204
+
205
+
206
+ def check_smiles_validity(smiles):
207
+ """
208
+ Check whether the smile can't be converted to rdkit mol object.
209
+ """
210
+ try:
211
+ m = Chem.MolFromSmiles(smiles)
212
+ if m:
213
+ return True
214
+ else:
215
+ return False
216
+ except Exception as e:
217
+ return False
218
+
219
+
220
+ def split_rdkit_mol_obj(mol):
221
+ """
222
+ Split rdkit mol object containing multiple species or one species into a
223
+ list of mol objects or a list containing a single object respectively.
224
+ Args:
225
+ mol: rdkit mol object.
226
+ """
227
+ smiles = AllChem.MolToSmiles(mol, isomericSmiles=True)
228
+ smiles_list = smiles.split('.')
229
+ mol_species_list = []
230
+ for s in smiles_list:
231
+ if check_smiles_validity(s):
232
+ mol_species_list.append(AllChem.MolFromSmiles(s))
233
+ return mol_species_list
234
+
235
+
236
+ def get_largest_mol(mol_list):
237
+ """
238
+ Given a list of rdkit mol objects, returns mol object containing the
239
+ largest num of atoms. If multiple containing largest num of atoms,
240
+ picks the first one.
241
+ Args:
242
+ mol_list(list): a list of rdkit mol object.
243
+ Returns:
244
+ the largest mol.
245
+ """
246
+ num_atoms_list = [len(m.GetAtoms()) for m in mol_list]
247
+ largest_mol_idx = num_atoms_list.index(max(num_atoms_list))
248
+ return mol_list[largest_mol_idx]
249
+
250
+
251
+ def rdchem_enum_to_list(values):
252
+ """values = {0: rdkit.Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
253
+ 1: rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
254
+ 2: rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
255
+ 3: rdkit.Chem.rdchem.ChiralType.CHI_OTHER}
256
+ """
257
+ return [values[i] for i in range(len(values))]
258
+
259
+
260
+ def safe_index(alist, elem):
261
+ """
262
+ Return index of element e in list l. If e is not present, return the last index
263
+ """
264
+ try:
265
+ return alist.index(elem)
266
+ except ValueError:
267
+ return len(alist) - 1
268
+
269
+
270
+ def get_atom_feature_dims(list_acquired_feature_names):
271
+ """ tbd
272
+ """
273
+ return list(map(len, [CompoundKit.atom_vocab_dict[name] for name in list_acquired_feature_names]))
274
+
275
+
276
+ def get_bond_feature_dims(list_acquired_feature_names):
277
+ """ tbd
278
+ """
279
+ list_bond_feat_dim = list(map(len, [CompoundKit.bond_vocab_dict[name] for name in list_acquired_feature_names]))
280
+ # +1 for self loop edges
281
+ return [_l + 1 for _l in list_bond_feat_dim]
282
+
283
+
284
+ class CompoundKit(object):
285
+ """
286
+ CompoundKit
287
+ """
288
+ atom_vocab_dict = {
289
+ "atomic_num": list(range(1, 119)) + ['misc'],
290
+ "chiral_tag": rdchem_enum_to_list(rdchem.ChiralType.values),
291
+ "degree": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
292
+ "explicit_valence": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 'misc'],
293
+ "formal_charge": [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
294
+ "hybridization": rdchem_enum_to_list(rdchem.HybridizationType.values),
295
+ "implicit_valence": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 'misc'],
296
+ "is_aromatic": [0, 1],
297
+ "total_numHs": [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
298
+ 'num_radical_e': [0, 1, 2, 3, 4, 'misc'],
299
+ 'atom_is_in_ring': [0, 1],
300
+ 'valence_out_shell': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
301
+ 'in_num_ring_with_size3': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
302
+ 'in_num_ring_with_size4': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
303
+ 'in_num_ring_with_size5': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
304
+ 'in_num_ring_with_size6': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
305
+ 'in_num_ring_with_size7': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
306
+ 'in_num_ring_with_size8': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
307
+ }
308
+ bond_vocab_dict = {
309
+ "bond_dir": rdchem_enum_to_list(rdchem.BondDir.values),
310
+ "bond_type": rdchem_enum_to_list(rdchem.BondType.values),
311
+ "is_in_ring": [0, 1],
312
+
313
+ 'bond_stereo': rdchem_enum_to_list(rdchem.BondStereo.values),
314
+ 'is_conjugated': [0, 1],
315
+ }
316
+ # float features
317
+ atom_float_names = ["van_der_waals_radis", "partial_charge", 'mass']
318
+ # bond_float_feats= ["bond_length", "bond_angle"] # optional
319
+
320
+ ### functional groups
321
+ day_light_fg_smarts_list = DAY_LIGHT_FG_SMARTS_LIST
322
+ day_light_fg_mo_list = [Chem.MolFromSmarts(smarts) for smarts in day_light_fg_smarts_list]
323
+
324
+ morgan_fp_N = 200
325
+ morgan2048_fp_N = 2048
326
+ maccs_fp_N = 167
327
+
328
+ period_table = Chem.GetPeriodicTable()
329
+
330
+ ### atom
331
+
332
+ @staticmethod
333
+ def get_atom_value(atom, name):
334
+ """get atom values"""
335
+ if name == 'atomic_num':
336
+ return atom.GetAtomicNum()
337
+ elif name == 'chiral_tag':
338
+ return atom.GetChiralTag()
339
+ elif name == 'degree':
340
+ return atom.GetDegree()
341
+ elif name == 'explicit_valence':
342
+ return atom.GetExplicitValence()
343
+ elif name == 'formal_charge':
344
+ return atom.GetFormalCharge()
345
+ elif name == 'hybridization':
346
+ return atom.GetHybridization()
347
+ elif name == 'implicit_valence':
348
+ return atom.GetImplicitValence()
349
+ elif name == 'is_aromatic':
350
+ return int(atom.GetIsAromatic())
351
+ elif name == 'mass':
352
+ return int(atom.GetMass())
353
+ elif name == 'total_numHs':
354
+ return atom.GetTotalNumHs()
355
+ elif name == 'num_radical_e':
356
+ return atom.GetNumRadicalElectrons()
357
+ elif name == 'atom_is_in_ring':
358
+ return int(atom.IsInRing())
359
+ elif name == 'valence_out_shell':
360
+ return CompoundKit.period_table.GetNOuterElecs(atom.GetAtomicNum())
361
+ else:
362
+ raise ValueError(name)
363
+
364
+ @staticmethod
365
+ def get_atom_feature_id(atom, name):
366
+ """get atom features id"""
367
+ assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name
368
+ return safe_index(CompoundKit.atom_vocab_dict[name], CompoundKit.get_atom_value(atom, name))
369
+
370
+ @staticmethod
371
+ def get_atom_feature_size(name):
372
+ """get atom features size"""
373
+ assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name
374
+ return len(CompoundKit.atom_vocab_dict[name])
375
+
376
+ ### bond
377
+
378
+ @staticmethod
379
+ def get_bond_value(bond, name):
380
+ """get bond values"""
381
+ if name == 'bond_dir':
382
+ return bond.GetBondDir()
383
+ elif name == 'bond_type':
384
+ return bond.GetBondType()
385
+ elif name == 'is_in_ring':
386
+ return int(bond.IsInRing())
387
+ elif name == 'is_conjugated':
388
+ return int(bond.GetIsConjugated())
389
+ elif name == 'bond_stereo':
390
+ return bond.GetStereo()
391
+ else:
392
+ raise ValueError(name)
393
+
394
+ @staticmethod
395
+ def get_bond_feature_id(bond, name):
396
+ """get bond features id"""
397
+ assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name
398
+ return safe_index(CompoundKit.bond_vocab_dict[name], CompoundKit.get_bond_value(bond, name))
399
+
400
+ @staticmethod
401
+ def get_bond_feature_size(name):
402
+ """get bond features size"""
403
+ assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name
404
+ return len(CompoundKit.bond_vocab_dict[name])
405
+
406
+ ### fingerprint
407
+
408
+ @staticmethod
409
+ def get_morgan_fingerprint(mol, radius=2):
410
+ """get morgan fingerprint"""
411
+ nBits = CompoundKit.morgan_fp_N
412
+ mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
413
+ return [int(b) for b in mfp.ToBitString()]
414
+
415
+ @staticmethod
416
+ def get_morgan2048_fingerprint(mol, radius=2):
417
+ """get morgan2048 fingerprint"""
418
+ nBits = CompoundKit.morgan2048_fp_N
419
+ mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
420
+ return [int(b) for b in mfp.ToBitString()]
421
+
422
+ @staticmethod
423
+ def get_maccs_fingerprint(mol):
424
+ """get maccs fingerprint"""
425
+ fp = AllChem.GetMACCSKeysFingerprint(mol)
426
+ return [int(b) for b in fp.ToBitString()]
427
+
428
+ ### functional groups
429
+
430
+ @staticmethod
431
+ def get_daylight_functional_group_counts(mol):
432
+ """get daylight functional group counts"""
433
+ fg_counts = []
434
+ for fg_mol in CompoundKit.day_light_fg_mo_list:
435
+ sub_structs = Chem.Mol.GetSubstructMatches(mol, fg_mol, uniquify=True)
436
+ fg_counts.append(len(sub_structs))
437
+ return fg_counts
438
+
439
+ @staticmethod
440
+ def get_ring_size(mol):
441
+ """return (N,6) list"""
442
+ rings = mol.GetRingInfo()
443
+ rings_info = []
444
+ for r in rings.AtomRings():
445
+ rings_info.append(r)
446
+ ring_list = []
447
+ for atom in mol.GetAtoms():
448
+ atom_result = []
449
+ for ringsize in range(3, 9):
450
+ num_of_ring_at_ringsize = 0
451
+ for r in rings_info:
452
+ if len(r) == ringsize and atom.GetIdx() in r:
453
+ num_of_ring_at_ringsize += 1
454
+ if num_of_ring_at_ringsize > 8:
455
+ num_of_ring_at_ringsize = 9
456
+ atom_result.append(num_of_ring_at_ringsize)
457
+
458
+ ring_list.append(atom_result)
459
+ return ring_list
460
+
461
+ @staticmethod
462
+ def atom_to_feat_vector(atom):
463
+ """ tbd """
464
+ atom_names = {
465
+ "atomic_num": safe_index(CompoundKit.atom_vocab_dict["atomic_num"], atom.GetAtomicNum()),
466
+ "chiral_tag": safe_index(CompoundKit.atom_vocab_dict["chiral_tag"], atom.GetChiralTag()),
467
+ "degree": safe_index(CompoundKit.atom_vocab_dict["degree"], atom.GetTotalDegree()),
468
+ "explicit_valence": safe_index(CompoundKit.atom_vocab_dict["explicit_valence"], atom.GetExplicitValence()),
469
+ "formal_charge": safe_index(CompoundKit.atom_vocab_dict["formal_charge"], atom.GetFormalCharge()),
470
+ "hybridization": safe_index(CompoundKit.atom_vocab_dict["hybridization"], atom.GetHybridization()),
471
+ "implicit_valence": safe_index(CompoundKit.atom_vocab_dict["implicit_valence"], atom.GetImplicitValence()),
472
+ "is_aromatic": safe_index(CompoundKit.atom_vocab_dict["is_aromatic"], int(atom.GetIsAromatic())),
473
+ "total_numHs": safe_index(CompoundKit.atom_vocab_dict["total_numHs"], atom.GetTotalNumHs()),
474
+ 'num_radical_e': safe_index(CompoundKit.atom_vocab_dict['num_radical_e'], atom.GetNumRadicalElectrons()),
475
+ 'atom_is_in_ring': safe_index(CompoundKit.atom_vocab_dict['atom_is_in_ring'], int(atom.IsInRing())),
476
+ 'valence_out_shell': safe_index(CompoundKit.atom_vocab_dict['valence_out_shell'],
477
+ CompoundKit.period_table.GetNOuterElecs(atom.GetAtomicNum())),
478
+ 'van_der_waals_radis': CompoundKit.period_table.GetRvdw(atom.GetAtomicNum()),
479
+ 'partial_charge': CompoundKit.check_partial_charge(atom),
480
+ 'mass': atom.GetMass(),
481
+ }
482
+ return atom_names
483
+
484
+ @staticmethod
485
+ def get_atom_names(mol):
486
+ """get atom name list
487
+ TODO: to be remove in the future
488
+ """
489
+ atom_features_dicts = []
490
+ Chem.rdPartialCharges.ComputeGasteigerCharges(mol)
491
+ for i, atom in enumerate(mol.GetAtoms()):
492
+ atom_features_dicts.append(CompoundKit.atom_to_feat_vector(atom))
493
+
494
+ ring_list = CompoundKit.get_ring_size(mol)
495
+ for i, atom in enumerate(mol.GetAtoms()):
496
+ atom_features_dicts[i]['in_num_ring_with_size3'] = safe_index(
497
+ CompoundKit.atom_vocab_dict['in_num_ring_with_size3'], ring_list[i][0])
498
+ atom_features_dicts[i]['in_num_ring_with_size4'] = safe_index(
499
+ CompoundKit.atom_vocab_dict['in_num_ring_with_size4'], ring_list[i][1])
500
+ atom_features_dicts[i]['in_num_ring_with_size5'] = safe_index(
501
+ CompoundKit.atom_vocab_dict['in_num_ring_with_size5'], ring_list[i][2])
502
+ atom_features_dicts[i]['in_num_ring_with_size6'] = safe_index(
503
+ CompoundKit.atom_vocab_dict['in_num_ring_with_size6'], ring_list[i][3])
504
+ atom_features_dicts[i]['in_num_ring_with_size7'] = safe_index(
505
+ CompoundKit.atom_vocab_dict['in_num_ring_with_size7'], ring_list[i][4])
506
+ atom_features_dicts[i]['in_num_ring_with_size8'] = safe_index(
507
+ CompoundKit.atom_vocab_dict['in_num_ring_with_size8'], ring_list[i][5])
508
+
509
+ return atom_features_dicts
510
+
511
+ @staticmethod
512
+ def check_partial_charge(atom):
513
+ """tbd"""
514
+ pc = atom.GetDoubleProp('_GasteigerCharge')
515
+ if pc != pc:
516
+ # unsupported atom, replace nan with 0
517
+ pc = 0
518
+ if pc == float('inf'):
519
+ # max 4 for other atoms, set to 10 here if inf is get
520
+ pc = 10
521
+ return pc
522
+
523
+
524
+ class Compound3DKit(object):
525
+ """the 3Dkit of Compound"""
526
+
527
+ @staticmethod
528
+ def get_atom_poses(mol, conf):
529
+ """tbd"""
530
+ atom_poses = []
531
+ for i, atom in enumerate(mol.GetAtoms()):
532
+ if atom.GetAtomicNum() == 0:
533
+ return [[0.0, 0.0, 0.0]] * len(mol.GetAtoms())
534
+ pos = conf.GetAtomPosition(i)
535
+ atom_poses.append([pos.x, pos.y, pos.z])
536
+ return atom_poses
537
+
538
+ @staticmethod
539
+ def get_MMFF_atom_poses(mol, numConfs=None, return_energy=False):
540
+ """the atoms of mol will be changed in some cases."""
541
+ conf = mol.GetConformer()
542
+ atom_poses = Compound3DKit.get_atom_poses(mol, conf)
543
+ return mol,atom_poses
544
+ # try:
545
+ # new_mol = Chem.AddHs(mol)
546
+ # res = AllChem.EmbedMultipleConfs(new_mol, numConfs=numConfs)
547
+ # ### MMFF generates multiple conformations
548
+ # res = AllChem.MMFFOptimizeMoleculeConfs(new_mol)
549
+ # new_mol = Chem.RemoveHs(new_mol)
550
+ # index = np.argmin([x[1] for x in res])
551
+ # energy = res[index][1]
552
+ # conf = new_mol.GetConformer(id=int(index))
553
+ # except:
554
+ # new_mol = mol
555
+ # AllChem.Compute2DCoords(new_mol)
556
+ # energy = 0
557
+ # conf = new_mol.GetConformer()
558
+ #
559
+ # atom_poses = Compound3DKit.get_atom_poses(new_mol, conf)
560
+ # if return_energy:
561
+ # return new_mol, atom_poses, energy
562
+ # else:
563
+ # return new_mol, atom_poses
564
+
565
+ @staticmethod
566
+ def get_2d_atom_poses(mol):
567
+ """get 2d atom poses"""
568
+ AllChem.Compute2DCoords(mol)
569
+ conf = mol.GetConformer()
570
+ atom_poses = Compound3DKit.get_atom_poses(mol, conf)
571
+ return atom_poses
572
+
573
+ @staticmethod
574
+ def get_bond_lengths(edges, atom_poses):
575
+ """get bond lengths"""
576
+ bond_lengths = []
577
+ for src_node_i, tar_node_j in edges:
578
+ bond_lengths.append(np.linalg.norm(atom_poses[tar_node_j] - atom_poses[src_node_i]))
579
+ bond_lengths = np.array(bond_lengths, 'float32')
580
+ return bond_lengths
581
+
582
+ @staticmethod
583
+ def get_superedge_angles(edges, atom_poses, dir_type='HT'):
584
+ """get superedge angles"""
585
+
586
+ def _get_vec(atom_poses, edge):
587
+ return atom_poses[edge[1]] - atom_poses[edge[0]]
588
+
589
+ def _get_angle(vec1, vec2):
590
+ norm1 = np.linalg.norm(vec1)
591
+ norm2 = np.linalg.norm(vec2)
592
+ if norm1 == 0 or norm2 == 0:
593
+ return 0
594
+ vec1 = vec1 / (norm1 + 1e-5) # 1e-5: prevent numerical errors
595
+ vec2 = vec2 / (norm2 + 1e-5)
596
+ angle = np.arccos(np.dot(vec1, vec2))
597
+ return angle
598
+
599
+ E = len(edges)
600
+ edge_indices = np.arange(E)
601
+ super_edges = []
602
+ bond_angles = []
603
+ bond_angle_dirs = []
604
+ for tar_edge_i in range(E):
605
+ tar_edge = edges[tar_edge_i]
606
+ if dir_type == 'HT':
607
+ src_edge_indices = edge_indices[edges[:, 1] == tar_edge[0]]
608
+ elif dir_type == 'HH':
609
+ src_edge_indices = edge_indices[edges[:, 1] == tar_edge[1]]
610
+ else:
611
+ raise ValueError(dir_type)
612
+ for src_edge_i in src_edge_indices:
613
+ if src_edge_i == tar_edge_i:
614
+ continue
615
+ src_edge = edges[src_edge_i]
616
+ src_vec = _get_vec(atom_poses, src_edge)
617
+ tar_vec = _get_vec(atom_poses, tar_edge)
618
+ super_edges.append([src_edge_i, tar_edge_i])
619
+ angle = _get_angle(src_vec, tar_vec)
620
+ bond_angles.append(angle)
621
+ bond_angle_dirs.append(src_edge[1] == tar_edge[0]) # H -> H or H -> T
622
+
623
+ if len(super_edges) == 0:
624
+ super_edges = np.zeros([0, 2], 'int64')
625
+ bond_angles = np.zeros([0, ], 'float32')
626
+ else:
627
+ super_edges = np.array(super_edges, 'int64')
628
+ bond_angles = np.array(bond_angles, 'float32')
629
+ return super_edges, bond_angles, bond_angle_dirs
630
+
631
+
632
+ def new_smiles_to_graph_data(smiles, **kwargs):
633
+ """
634
+ Convert smiles to graph data.
635
+ """
636
+ mol = AllChem.MolFromSmiles(smiles)
637
+ if mol is None:
638
+ return None
639
+ data = new_mol_to_graph_data(mol)
640
+ return data
641
+
642
+
643
+ def new_mol_to_graph_data(mol):
644
+ """
645
+ mol_to_graph_data
646
+ Args:
647
+ atom_features: Atom features.
648
+ edge_features: Edge features.
649
+ morgan_fingerprint: Morgan fingerprint.
650
+ functional_groups: Functional groups.
651
+ """
652
+ if len(mol.GetAtoms()) == 0:
653
+ return None
654
+
655
+ atom_id_names = list(CompoundKit.atom_vocab_dict.keys()) + CompoundKit.atom_float_names
656
+ bond_id_names = list(CompoundKit.bond_vocab_dict.keys())
657
+
658
+ data = {}
659
+
660
+ ### atom features
661
+ data = {name: [] for name in atom_id_names}
662
+
663
+ raw_atom_feat_dicts = CompoundKit.get_atom_names(mol)
664
+ for atom_feat in raw_atom_feat_dicts:
665
+ for name in atom_id_names:
666
+ data[name].append(atom_feat[name])
667
+
668
+ ### bond and bond features
669
+ for name in bond_id_names:
670
+ data[name] = []
671
+ data['edges'] = []
672
+
673
+ for bond in mol.GetBonds():
674
+ i = bond.GetBeginAtomIdx()
675
+ j = bond.GetEndAtomIdx()
676
+ # i->j and j->i
677
+ data['edges'] += [(i, j), (j, i)]
678
+ for name in bond_id_names:
679
+ bond_feature_id = CompoundKit.get_bond_feature_id(bond, name)
680
+ data[name] += [bond_feature_id] * 2
681
+
682
+ #### self loop
683
+ N = len(data[atom_id_names[0]])
684
+ for i in range(N):
685
+ data['edges'] += [(i, i)]
686
+ for name in bond_id_names:
687
+ bond_feature_id = get_bond_feature_dims([name])[0] - 1 # self loop: value = len - 1
688
+ data[name] += [bond_feature_id] * N
689
+
690
+ ### make ndarray and check length
691
+ for name in list(CompoundKit.atom_vocab_dict.keys()):
692
+ data[name] = np.array(data[name], 'int64')
693
+ for name in CompoundKit.atom_float_names:
694
+ data[name] = np.array(data[name], 'float32')
695
+ for name in bond_id_names:
696
+ data[name] = np.array(data[name], 'int64')
697
+ data['edges'] = np.array(data['edges'], 'int64')
698
+
699
+ ### morgan fingerprint
700
+ data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64')
701
+ # data['morgan2048_fp'] = np.array(CompoundKit.get_morgan2048_fingerprint(mol), 'int64')
702
+ data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64')
703
+ data['daylight_fg_counts'] = np.array(CompoundKit.get_daylight_functional_group_counts(mol), 'int64')
704
+ return data
705
+
706
+
707
+ def mol_to_graph_data(mol):
708
+ """
709
+ mol_to_graph_data
710
+ Args:
711
+ atom_features: Atom features.
712
+ edge_features: Edge features.
713
+ morgan_fingerprint: Morgan fingerprint.
714
+ functional_groups: Functional groups.
715
+ """
716
+ if len(mol.GetAtoms()) == 0:
717
+ return None
718
+
719
+ atom_id_names = [
720
+ "atomic_num", "chiral_tag", "degree", "explicit_valence",
721
+ "formal_charge", "hybridization", "implicit_valence",
722
+ "is_aromatic", "total_numHs",
723
+ ]
724
+ bond_id_names = [
725
+ "bond_dir", "bond_type", "is_in_ring",
726
+ ]
727
+
728
+ data = {}
729
+ for name in atom_id_names:
730
+ data[name] = []
731
+ data['mass'] = []
732
+ for name in bond_id_names:
733
+ data[name] = []
734
+ data['edges'] = []
735
+
736
+ ### atom features
737
+ for i, atom in enumerate(mol.GetAtoms()):
738
+ if atom.GetAtomicNum() == 0:
739
+ return None
740
+ for name in atom_id_names:
741
+ data[name].append(CompoundKit.get_atom_feature_id(atom, name) + 1) # 0: OOV
742
+ data['mass'].append(CompoundKit.get_atom_value(atom, 'mass') * 0.01)
743
+
744
+ ### bond features
745
+ for bond in mol.GetBonds():
746
+ i = bond.GetBeginAtomIdx()
747
+ j = bond.GetEndAtomIdx()
748
+ # i->j and j->i
749
+ data['edges'] += [(i, j), (j, i)]
750
+ for name in bond_id_names:
751
+ bond_feature_id = CompoundKit.get_bond_feature_id(bond, name) + 1 # 0: OOV
752
+ data[name] += [bond_feature_id] * 2
753
+
754
+ ### self loop (+2)
755
+ N = len(data[atom_id_names[0]])
756
+ for i in range(N):
757
+ data['edges'] += [(i, i)]
758
+ for name in bond_id_names:
759
+ bond_feature_id = CompoundKit.get_bond_feature_size(name) + 2 # N + 2: self loop
760
+ data[name] += [bond_feature_id] * N
761
+
762
+ ### check whether edge exists
763
+ if len(data['edges']) == 0: # mol has no bonds
764
+ for name in bond_id_names:
765
+ data[name] = np.zeros((0,), dtype="int64")
766
+ data['edges'] = np.zeros((0, 2), dtype="int64")
767
+
768
+ ### make ndarray and check length
769
+ for name in atom_id_names:
770
+ data[name] = np.array(data[name], 'int64')
771
+ data['mass'] = np.array(data['mass'], 'float32')
772
+ for name in bond_id_names:
773
+ data[name] = np.array(data[name], 'int64')
774
+ data['edges'] = np.array(data['edges'], 'int64')
775
+
776
+ ### morgan fingerprint
777
+ data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64')
778
+ # data['morgan2048_fp'] = np.array(CompoundKit.get_morgan2048_fingerprint(mol), 'int64')
779
+ data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64')
780
+ data['daylight_fg_counts'] = np.array(CompoundKit.get_daylight_functional_group_counts(mol), 'int64')
781
+ return data
782
+
783
+
784
+ def mol_to_geognn_graph_data(mol, atom_poses, dir_type):
785
+ """
786
+ mol: rdkit molecule
787
+ dir_type: direction type for bond_angle grpah
788
+ """
789
+ if len(mol.GetAtoms()) == 0:
790
+ return None
791
+
792
+ data = mol_to_graph_data(mol)
793
+
794
+ data['atom_pos'] = np.array(atom_poses, 'float32')
795
+ data['bond_length'] = Compound3DKit.get_bond_lengths(data['edges'], data['atom_pos'])
796
+ BondAngleGraph_edges, bond_angles, bond_angle_dirs = \
797
+ Compound3DKit.get_superedge_angles(data['edges'], data['atom_pos'])
798
+ data['BondAngleGraph_edges'] = BondAngleGraph_edges
799
+ data['bond_angle'] = np.array(bond_angles, 'float32')
800
+ return data
801
+
802
+
803
+ def mol_to_geognn_graph_data_MMFF3d(mol):
804
+ """tbd"""
805
+ if len(mol.GetAtoms()) <= 400:
806
+ mol, atom_poses = Compound3DKit.get_MMFF_atom_poses(mol, numConfs=10)
807
+ else:
808
+ atom_poses = Compound3DKit.get_2d_atom_poses(mol)
809
+ return mol_to_geognn_graph_data(mol, atom_poses, dir_type='HT')
810
+
811
+
812
+ def mol_to_geognn_graph_data_raw3d(mol):
813
+ """tbd"""
814
+ atom_poses = Compound3DKit.get_atom_poses(mol, mol.GetConformer())
815
+ return mol_to_geognn_graph_data(mol, atom_poses, dir_type='HT')
816
+
817
+ def obtain_3D_mol(smiles,name):
818
+ mol = AllChem.MolFromSmiles(smiles)
819
+ new_mol = Chem.AddHs(mol)
820
+ res = AllChem.EmbedMultipleConfs(new_mol)
821
+ ### MMFF generates multiple conformations
822
+ res = AllChem.MMFFOptimizeMoleculeConfs(new_mol)
823
+ new_mol = Chem.RemoveHs(new_mol)
824
+ Chem.MolToMolFile(new_mol, name+'.mol')
825
+ return new_mol
826
+
827
+ os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
828
+ warnings.filterwarnings('ignore')
829
+
830
+ #============Parameter setting===============
831
+ MODEL = 'Test' #['Train','Test','Test_other_method','Test_enantiomer','Test_excel']
832
+ test_mode='fixed' #fixed or random or enantiomer(extract enantimoers)
833
+ transfer_target='All_column' #trail name
834
+ Use_geometry_enhanced=True #default:True
835
+ Use_column_info=True #default: True
836
+
837
+ atom_id_names = [
838
+ "atomic_num", "chiral_tag", "degree", "explicit_valence",
839
+ "formal_charge", "hybridization", "implicit_valence",
840
+ "is_aromatic", "total_numHs",
841
+ ]
842
+ bond_id_names = [
843
+ "bond_dir", "bond_type", "is_in_ring"]
844
+
845
+ if Use_geometry_enhanced==True:
846
+ bond_float_names = ["bond_length",'prop']
847
+
848
+ if Use_geometry_enhanced==False:
849
+ bond_float_names=['prop']
850
+
851
+ bond_angle_float_names = ['bond_angle', 'TPSA', 'RASA', 'RPSA', 'MDEC', 'MATS']
852
+
853
+ column_specify={'ADH':[1,5,0,0],'ODH':[1,5,0,1],'IC':[0,5,1,2],'IA':[0,5,1,3],'OJH':[1,5,0,4],
854
+ 'ASH':[1,5,0,5],'IC3':[0,3,1,6],'IE':[0,5,1,7],'ID':[0,5,1,8],'OD3':[1,3,0,9],
855
+ 'IB':[0,5,1,10],'AD':[1,10,0,11],'AD3':[1,3,0,12],'IF':[0,5,1,13],'OD':[1,10,0,14],
856
+ 'AS':[1,10,0,15],'OJ3':[1,3,0,16],'IG':[0,5,1,17],'AZ':[1,10,0,18],'IAH':[0,5,1,19],
857
+ 'OJ':[1,10,0,20],'ICH':[0,5,1,21],'OZ3':[1,3,0,22],'IF3':[0,3,1,23],'IAU':[0,1.6,1,24]}
858
+ column_smile=['O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(C)=C4',
859
+ 'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@@H]1OC)NC4=CC(C)=CC(C)=C4',
860
+ 'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(Cl)=CC(Cl)=C2)=O)[C@@H](OC(NC3=CC(Cl)=CC(Cl)=C3)=O)[C@@H]1OC)NC4=CC(Cl)=CC(Cl)=C4',
861
+ 'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(C)=C4',
862
+ 'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(C2=CC=C(C)C=C2)=O)[C@@H](OC(C3=CC=C(C)C=C3)=O)[C@@H]1OC)C4=CC=C(C)C=C4',
863
+ 'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(N[C@@H](C)C2=CC=CC=C2)=O)[C@@H](OC(N[C@@H](C)C3=CC=CC=C3)=O)[C@H]1OC)N[C@@H](C)C4=CC=CC=C4',
864
+ 'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(Cl)=CC(Cl)=C2)=O)[C@@H](OC(NC3=CC(Cl)=CC(Cl)=C3)=O)[C@@H]1OC)NC4=CC(Cl)=CC(Cl)=C4',
865
+ 'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(Cl)=CC(Cl)=C2)=O)[C@@H](OC(NC3=CC(Cl)=CC(Cl)=C3)=O)[C@H]1OC)NC4=CC(Cl)=CC(Cl)=C4',
866
+ 'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC=CC(Cl)=C2)=O)[C@@H](OC(NC3=CC=CC(Cl)=C3)=O)[C@H]1OC)NC4=CC=CC(Cl)=C4',
867
+ 'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@@H]1OC)NC4=CC(C)=CC(C)=C4',
868
+ 'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@@H]1OC)NC4=CC(C)=CC(C)=C4',
869
+ 'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(C)=C4',
870
+ 'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(C)=C4',
871
+ 'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC=C(C)C(Cl)=C2)=O)[C@@H](OC(NC3=CC=C(C)C(Cl)=C3)=O)[C@H]1OC)NC4=CC=C(C)C(Cl)=C4',
872
+ 'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@@H]1OC)NC4=CC(C)=CC(C)=C4',
873
+ 'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(N[C@@H](C)C2=CC=CC=C2)=O)[C@@H](OC(N[C@@H](C)C3=CC=CC=C3)=O)[C@H]1OC)N[C@@H](C)C4=CC=CC=C4',
874
+ 'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(C2=CC=C(C)C=C2)=O)[C@@H](OC(C3=CC=C(C)C=C3)=O)[C@@H]1OC)C4=CC=C(C)C=C4',
875
+ 'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(Cl)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(Cl)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(Cl)=C4',
876
+ 'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC=C(C)C(Cl)=C2)=O)[C@@H](OC(NC3=CC=C(C)C(Cl)=C3)=O)[C@H]1OC)NC4=CC=C(C)C(Cl)=C4',
877
+ 'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(C)=C4',
878
+ 'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(C2=CC=C(C)C=C2)=O)[C@@H](OC(C3=CC=C(C)C=C3)=O)[C@@H]1OC)C4=CC=C(C)C=C4',
879
+ 'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(Cl)=CC(Cl)=C2)=O)[C@@H](OC(NC3=CC(Cl)=CC(Cl)=C3)=O)[C@@H]1OC)NC4=CC(Cl)=CC(Cl)=C4',
880
+ 'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC=C(C)C(Cl)=C2)=O)[C@@H](OC(NC3=CC=C(C)C(Cl)=C3)=O)[C@@H]1OC)NC4=CC=C(C)C(Cl)=C4',
881
+ 'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC=C(C)C(Cl)=C2)=O)[C@@H](OC(NC3=CC=C(C)C(Cl)=C3)=O)[C@H]1OC)NC4=CC=C(C)C(Cl)=C4',
882
+ 'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(C)=C4']
883
+ column_name=['ADH','ODH','IC','IA','OJH','ASH','IC3','IE','ID','OD3', 'IB','AD','AD3',
884
+ 'IF','OD','AS','OJ3','IG','AZ','IAH','OJ','ICH','OZ3','IF3','IAU']
885
+ full_atom_feature_dims = get_atom_feature_dims(atom_id_names)
886
+ full_bond_feature_dims = get_bond_feature_dims(bond_id_names)
887
+
888
+
889
+ if Use_column_info==True:
890
+ bond_id_names.extend(['coated', 'immobilized'])
891
+ bond_float_names.extend(['diameter'])
892
+ if Use_geometry_enhanced==True:
893
+ bond_angle_float_names.extend(['column_TPSA', 'column_TPSA', 'column_TPSA', 'column_MDEC', 'column_MATS'])
894
+ else:
895
+ bond_float_names.extend(['column_TPSA', 'column_TPSA', 'column_TPSA', 'column_MDEC', 'column_MATS'])
896
+ full_bond_feature_dims.extend([2,2])
897
+
898
+ calc = Calculator(descriptors, ignore_3D=False)
899
+
900
+
901
+ class AtomEncoder(torch.nn.Module):
902
+
903
+ def __init__(self, emb_dim):
904
+ super(AtomEncoder, self).__init__()
905
+
906
+ self.atom_embedding_list = torch.nn.ModuleList()
907
+
908
+ for i, dim in enumerate(full_atom_feature_dims):
909
+ emb = torch.nn.Embedding(dim + 5, emb_dim) # 不同维度的属性用不同的Embedding方法
910
+ torch.nn.init.xavier_uniform_(emb.weight.data)
911
+ self.atom_embedding_list.append(emb)
912
+
913
+ def forward(self, x):
914
+ x_embedding = 0
915
+ for i in range(x.shape[1]):
916
+ x_embedding += self.atom_embedding_list[i](x[:, i])
917
+
918
+ return x_embedding
919
+
920
+ class BondEncoder(torch.nn.Module):
921
+
922
+ def __init__(self, emb_dim):
923
+ super(BondEncoder, self).__init__()
924
+
925
+ self.bond_embedding_list = torch.nn.ModuleList()
926
+
927
+ for i, dim in enumerate(full_bond_feature_dims):
928
+ emb = torch.nn.Embedding(dim + 5, emb_dim)
929
+ torch.nn.init.xavier_uniform_(emb.weight.data)
930
+ self.bond_embedding_list.append(emb)
931
+
932
+ def forward(self, edge_attr):
933
+ bond_embedding = 0
934
+ for i in range(edge_attr.shape[1]):
935
+ bond_embedding += self.bond_embedding_list[i](edge_attr[:, i])
936
+
937
+ return bond_embedding
938
+
939
+ class RBF(torch.nn.Module):
940
+ """
941
+ Radial Basis Function
942
+ """
943
+
944
+ def __init__(self, centers, gamma, dtype='float32'):
945
+ super(RBF, self).__init__()
946
+ self.centers = centers.reshape([1, -1])
947
+ self.gamma = gamma
948
+
949
+ def forward(self, x):
950
+ """
951
+ Args:
952
+ x(tensor): (-1, 1).
953
+ Returns:
954
+ y(tensor): (-1, n_centers)
955
+ """
956
+ x = x.reshape([-1, 1])
957
+ return torch.exp(-self.gamma * torch.square(x - self.centers))
958
+
959
+ class BondFloatRBF(torch.nn.Module):
960
+ """
961
+ Bond Float Encoder using Radial Basis Functions
962
+ """
963
+
964
+ def __init__(self, bond_float_names, embed_dim, rbf_params=None):
965
+ super(BondFloatRBF, self).__init__()
966
+ self.bond_float_names = bond_float_names
967
+
968
+ if rbf_params is None:
969
+ self.rbf_params = {
970
+ 'bond_length': (nn.Parameter(torch.arange(0, 2, 0.1)), nn.Parameter(torch.Tensor([10.0]))),
971
+ # (centers, gamma)
972
+ 'prop': (nn.Parameter(torch.arange(0, 1, 0.05)), nn.Parameter(torch.Tensor([1.0]))),
973
+ 'diameter': (nn.Parameter(torch.arange(3, 12, 0.3)), nn.Parameter(torch.Tensor([1.0]))),
974
+ ##=========Only for pure GNN===============
975
+ 'column_TPSA': (nn.Parameter(torch.arange(0, 1, 0.05).to(torch.float32)), nn.Parameter(torch.Tensor([1.0]))),
976
+ 'column_RASA': (nn.Parameter(torch.arange(0, 1, 0.05)), nn.Parameter(torch.Tensor([1.0]))),
977
+ 'column_RPSA': (nn.Parameter(torch.arange(0, 1, 0.05)), nn.Parameter(torch.Tensor([1.0]))),
978
+ 'column_MDEC': (nn.Parameter(torch.arange(0, 10, 0.5)), nn.Parameter(torch.Tensor([2.0]))),
979
+ 'column_MATS': (nn.Parameter(torch.arange(0, 1, 0.05)), nn.Parameter(torch.Tensor([1.0]))),
980
+ }
981
+ else:
982
+ self.rbf_params = rbf_params
983
+
984
+ self.linear_list = torch.nn.ModuleList()
985
+ self.rbf_list = torch.nn.ModuleList()
986
+ for name in self.bond_float_names:
987
+ centers, gamma = self.rbf_params[name]
988
+ rbf = RBF(centers.to(device), gamma.to(device))
989
+ self.rbf_list.append(rbf)
990
+ linear = torch.nn.Linear(len(centers), embed_dim).cuda()
991
+ self.linear_list.append(linear)
992
+
993
+ def forward(self, bond_float_features):
994
+ """
995
+ Args:
996
+ bond_float_features(dict of tensor): bond float features.
997
+ """
998
+ out_embed = 0
999
+ for i, name in enumerate(self.bond_float_names):
1000
+ x = bond_float_features[:, i].reshape(-1, 1)
1001
+ rbf_x = self.rbf_list[i](x)
1002
+ out_embed += self.linear_list[i](rbf_x)
1003
+ return out_embed
1004
+
1005
+ class BondAngleFloatRBF(torch.nn.Module):
1006
+ """
1007
+ Bond Angle Float Encoder using Radial Basis Functions
1008
+ """
1009
+
1010
+ def __init__(self, bond_angle_float_names, embed_dim, rbf_params=None):
1011
+ super(BondAngleFloatRBF, self).__init__()
1012
+ self.bond_angle_float_names = bond_angle_float_names
1013
+
1014
+ if rbf_params is None:
1015
+ self.rbf_params = {
1016
+ 'bond_angle': (nn.Parameter(torch.arange(0, torch.pi, 0.1)), nn.Parameter(torch.Tensor([10.0]))),
1017
+ }
1018
+ else:
1019
+ self.rbf_params = rbf_params
1020
+
1021
+ self.linear_list = torch.nn.ModuleList()
1022
+ self.rbf_list = torch.nn.ModuleList()
1023
+ for name in self.bond_angle_float_names:
1024
+ if name == 'bond_angle':
1025
+ centers, gamma = self.rbf_params[name]
1026
+ rbf = RBF(centers.to(device), gamma.to(device))
1027
+ self.rbf_list.append(rbf)
1028
+ linear = nn.Linear(len(centers), embed_dim)
1029
+ self.linear_list.append(linear)
1030
+ else:
1031
+ linear = nn.Linear(len(self.bond_angle_float_names) - 1, embed_dim)
1032
+ self.linear_list.append(linear)
1033
+ break
1034
+
1035
+ def forward(self, bond_angle_float_features):
1036
+ """
1037
+ Args:
1038
+ bond_angle_float_features(dict of tensor): bond angle float features.
1039
+ """
1040
+ out_embed = 0
1041
+ for i, name in enumerate(self.bond_angle_float_names):
1042
+ if name == 'bond_angle':
1043
+ x = bond_angle_float_features[:, i].reshape(-1, 1)
1044
+ rbf_x = self.rbf_list[i](x)
1045
+ out_embed += self.linear_list[i](rbf_x)
1046
+ else:
1047
+ x = bond_angle_float_features[:, 1:]
1048
+ out_embed += self.linear_list[i](x)
1049
+ break
1050
+ return out_embed
1051
+
1052
+ class GINConv(MessagePassing):
1053
+ def __init__(self, emb_dim):
1054
+ '''
1055
+ emb_dim (int): node embedding dimensionality
1056
+ '''
1057
+
1058
+ super(GINConv, self).__init__(aggr="add")
1059
+
1060
+ self.mlp = nn.Sequential(nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(),
1061
+ nn.Linear(emb_dim, emb_dim))
1062
+ self.eps = nn.Parameter(torch.Tensor([0]))
1063
+
1064
+ def forward(self, x, edge_index, edge_attr):
1065
+ edge_embedding = edge_attr
1066
+ out = self.mlp((1 + self.eps) * x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
1067
+ return out
1068
+
1069
+ def message(self, x_j, edge_attr):
1070
+ return F.relu(x_j + edge_attr)
1071
+
1072
+ def update(self, aggr_out):
1073
+ return aggr_out
1074
+
1075
+ # GNN to generate node embedding
1076
+ class GINNodeEmbedding(torch.nn.Module):
1077
+ """
1078
+ Output:
1079
+ node representations
1080
+ """
1081
+
1082
+ def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last", residual=False):
1083
+ """GIN Node Embedding Module
1084
+ 采用多层GINConv实现图上结点的嵌入。
1085
+ """
1086
+
1087
+ super(GINNodeEmbedding, self).__init__()
1088
+ self.num_layers = num_layers
1089
+ self.drop_ratio = drop_ratio
1090
+ self.JK = JK
1091
+ # add residual connection or not
1092
+ self.residual = residual
1093
+
1094
+ if self.num_layers < 2:
1095
+ raise ValueError("Number of GNN layers must be greater than 1.")
1096
+
1097
+ self.atom_encoder = AtomEncoder(emb_dim)
1098
+ self.bond_encoder=BondEncoder(emb_dim)
1099
+ self.bond_float_encoder=BondFloatRBF(bond_float_names,emb_dim)
1100
+ self.bond_angle_encoder=BondAngleFloatRBF(bond_angle_float_names,emb_dim)
1101
+
1102
+ # List of GNNs
1103
+ self.convs = torch.nn.ModuleList()
1104
+ self.convs_bond_angle=torch.nn.ModuleList()
1105
+ self.convs_bond_float=torch.nn.ModuleList()
1106
+ self.convs_bond_embeding=torch.nn.ModuleList()
1107
+ self.convs_angle_float=torch.nn.ModuleList()
1108
+ self.batch_norms = torch.nn.ModuleList()
1109
+ self.batch_norms_ba = torch.nn.ModuleList()
1110
+ for layer in range(num_layers):
1111
+ self.convs.append(GINConv(emb_dim))
1112
+ self.convs_bond_angle.append(GINConv(emb_dim))
1113
+ self.convs_bond_embeding.append(BondEncoder(emb_dim))
1114
+ self.convs_bond_float.append(BondFloatRBF(bond_float_names,emb_dim))
1115
+ self.convs_angle_float.append(BondAngleFloatRBF(bond_angle_float_names,emb_dim))
1116
+ self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
1117
+ self.batch_norms_ba.append(torch.nn.BatchNorm1d(emb_dim))
1118
+
1119
+ def forward(self, batched_atom_bond,batched_bond_angle):
1120
+ x, edge_index, edge_attr = batched_atom_bond.x, batched_atom_bond.edge_index, batched_atom_bond.edge_attr
1121
+ edge_index_ba,edge_attr_ba= batched_bond_angle.edge_index, batched_bond_angle.edge_attr
1122
+ # computing input node embedding
1123
+ h_list = [self.atom_encoder(x)] # 先将类别型原子属性转化为原子嵌入
1124
+
1125
+ if Use_geometry_enhanced==True:
1126
+ h_list_ba=[self.bond_float_encoder(edge_attr[:,len(bond_id_names):edge_attr.shape[1]+1].to(torch.float32))+self.bond_encoder(edge_attr[:,0:len(bond_id_names)].to(torch.int64))]
1127
+ for layer in range(self.num_layers):
1128
+ h = self.convs[layer](h_list[layer], edge_index, h_list_ba[layer])
1129
+ cur_h_ba=self.convs_bond_embeding[layer](edge_attr[:,0:len(bond_id_names)].to(torch.int64))+self.convs_bond_float[layer](edge_attr[:,len(bond_id_names):edge_attr.shape[1]+1].to(torch.float32))
1130
+ cur_angle_hidden=self.convs_angle_float[layer](edge_attr_ba)
1131
+ h_ba=self.convs_bond_angle[layer](cur_h_ba, edge_index_ba, cur_angle_hidden)
1132
+
1133
+ if layer == self.num_layers - 1:
1134
+ # remove relu for the last layer
1135
+ h = F.dropout(h, self.drop_ratio, training=self.training)
1136
+ h_ba = F.dropout(h_ba, self.drop_ratio, training=self.training)
1137
+ else:
1138
+ h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
1139
+ h_ba = F.dropout(F.relu(h_ba), self.drop_ratio, training=self.training)
1140
+ if self.residual:
1141
+ h += h_list[layer]
1142
+ h_ba+=h_list_ba[layer]
1143
+ h_list.append(h)
1144
+ h_list_ba.append(h_ba)
1145
+
1146
+
1147
+ # Different implementations of Jk-concat
1148
+ if self.JK == "last":
1149
+ node_representation = h_list[-1]
1150
+ edge_representation = h_list_ba[-1]
1151
+ elif self.JK == "sum":
1152
+ node_representation = 0
1153
+ edge_representation = 0
1154
+ for layer in range(self.num_layers + 1):
1155
+ node_representation += h_list[layer]
1156
+ edge_representation += h_list_ba[layer]
1157
+
1158
+ return node_representation,edge_representation
1159
+ if Use_geometry_enhanced==False:
1160
+ for layer in range(self.num_layers):
1161
+ h = self.convs[layer](h_list[layer], edge_index,
1162
+ self.convs_bond_embeding[layer](edge_attr[:, 0:len(bond_id_names)].to(torch.int64)) +
1163
+ self.convs_bond_float[layer](
1164
+ edge_attr[:, len(bond_id_names):edge_attr.shape[1] + 1].to(torch.float32)))
1165
+ h = self.batch_norms[layer](h)
1166
+ if layer == self.num_layers - 1:
1167
+ # remove relu for the last layer
1168
+ h = F.dropout(h, self.drop_ratio, training=self.training)
1169
+ else:
1170
+ h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
1171
+
1172
+ if self.residual:
1173
+ h += h_list[layer]
1174
+
1175
+ h_list.append(h)
1176
+
1177
+ # Different implementations of Jk-concat
1178
+ if self.JK == "last":
1179
+ node_representation = h_list[-1]
1180
+ elif self.JK == "sum":
1181
+ node_representation = 0
1182
+ for layer in range(self.num_layers + 1):
1183
+ node_representation += h_list[layer]
1184
+
1185
+ return node_representation
1186
+
1187
+ class GINGraphPooling(nn.Module):
1188
+
1189
+ def __init__(self, num_tasks=1, num_layers=5, emb_dim=300, residual=False, drop_ratio=0, JK="last", graph_pooling="attention",
1190
+ descriptor_dim=1781):
1191
+ """GIN Graph Pooling Module
1192
+
1193
+ 此模块首先采用GINNodeEmbedding模块对图上每一个节点做嵌入,然后对节点嵌入做池化得到图的嵌入,最后用一层线性变换得到图的最终的表示(graph representation)。
1194
+
1195
+ Args:
1196
+ num_tasks (int, optional): number of labels to be predicted. Defaults to 1 (控制了图表示的维度,dimension of graph representation).
1197
+ num_layers (int, optional): number of GINConv layers. Defaults to 5.
1198
+ emb_dim (int, optional): dimension of node embedding. Defaults to 300.
1199
+ residual (bool, optional): adding residual connection or not. Defaults to False.
1200
+ drop_ratio (float, optional): dropout rate. Defaults to 0.
1201
+ JK (str, optional): 可选的值为"last"和"sum"。选"last",只取最后一层的结点的嵌入,选"sum"对各层的结点的嵌入求和。Defaults to "last".
1202
+ graph_pooling (str, optional): pooling method of node embedding. 可选的值为"sum","mean","max","attention"和"set2set"。 Defaults to "sum".
1203
+
1204
+ Out:
1205
+ graph representation
1206
+ """
1207
+ super(GINGraphPooling, self).__init__()
1208
+
1209
+ self.num_layers = num_layers
1210
+ self.drop_ratio = drop_ratio
1211
+ self.JK = JK
1212
+ self.emb_dim = emb_dim
1213
+ self.num_tasks = num_tasks
1214
+ self.descriptor_dim=descriptor_dim
1215
+ if self.num_layers < 2:
1216
+ raise ValueError("Number of GNN layers must be greater than 1.")
1217
+
1218
+ self.gnn_node = GINNodeEmbedding(num_layers, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual)
1219
+
1220
+ # Pooling function to generate whole-graph embeddings
1221
+ if graph_pooling == "sum":
1222
+ self.pool = global_add_pool
1223
+
1224
+ elif graph_pooling == "mean":
1225
+ self.pool = global_mean_pool
1226
+
1227
+ elif graph_pooling == "max":
1228
+ self.pool = global_max_pool
1229
+
1230
+ elif graph_pooling == "attention":
1231
+ self.pool = GlobalAttention(gate_nn=nn.Sequential(
1232
+ nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, 1)))
1233
+
1234
+
1235
+ elif graph_pooling == "set2set":
1236
+ self.pool = Set2Set(emb_dim, processing_steps=2)
1237
+ else:
1238
+ raise ValueError("Invalid graph pooling type.")
1239
+
1240
+ if graph_pooling == "set2set":
1241
+ self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)
1242
+ else:
1243
+ self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)
1244
+
1245
+ self.NN_descriptor = nn.Sequential(nn.Linear(self.descriptor_dim, self.emb_dim),
1246
+ nn.Sigmoid(),
1247
+ nn.Linear(self.emb_dim, self.emb_dim))
1248
+
1249
+ self.sigmoid = nn.Sigmoid()
1250
+
1251
+ def forward(self, batched_atom_bond,batched_bond_angle):
1252
+ if Use_geometry_enhanced==True:
1253
+ h_node,h_node_ba= self.gnn_node(batched_atom_bond,batched_bond_angle)
1254
+ else:
1255
+ h_node= self.gnn_node(batched_atom_bond, batched_bond_angle)
1256
+ h_graph = self.pool(h_node, batched_atom_bond.batch)
1257
+ output = self.graph_pred_linear(h_graph)
1258
+ if self.training:
1259
+ return output,h_graph
1260
+ else:
1261
+ # At inference time, relu is applied to output to ensure positivity
1262
+ return torch.clamp(output, min=0, max=1e8),h_graph
1263
+
1264
+ def mord(mol, nBits=1826, errors_as_zeros=True):
1265
+ try:
1266
+ result = calc(mol)
1267
+ desc_list = [r if not is_missing(r) else 0 for r in result]
1268
+ np_arr = np.array(desc_list)
1269
+ return np_arr
1270
+ except:
1271
+ return np.NaN if not errors_as_zeros else np.zeros((nBits,), dtype=np.float32)
1272
+
1273
+ def load_3D_mol():
1274
+ dir = 'mol_save/'
1275
+ for root, dirs, files in os.walk(dir):
1276
+ file_names = files
1277
+ file_names.sort(key=lambda x: int(x[x.find('_') + 5:x.find(".")])) # 按照前面的数字字符排序
1278
+ mol_save = []
1279
+ for file_name in file_names:
1280
+ mol_save.append(Chem.MolFromMolFile(dir + file_name))
1281
+ return mol_save
1282
+
1283
+ def parse_args():
1284
+ parser = argparse.ArgumentParser(description='Graph data miming with GNN')
1285
+ parser.add_argument('--task_name', type=str, default='GINGraphPooling',
1286
+ help='task name')
1287
+ parser.add_argument('--device', type=int, default=0,
1288
+ help='which gpu to use if any (default: 0)')
1289
+ parser.add_argument('--num_layers', type=int, default=5,
1290
+ help='number of GNN message passing layers (default: 5)')
1291
+ parser.add_argument('--graph_pooling', type=str, default='sum',
1292
+ help='graph pooling strategy mean or sum (default: sum)')
1293
+ parser.add_argument('--emb_dim', type=int, default=128,
1294
+ help='dimensionality of hidden units in GNNs (default: 256)')
1295
+ parser.add_argument('--drop_ratio', type=float, default=0.,
1296
+ help='dropout ratio (default: 0.)')
1297
+ parser.add_argument('--save_test', action='store_true')
1298
+ parser.add_argument('--batch_size', type=int, default=2048,
1299
+ help='input batch size for training (default: 512)')
1300
+ parser.add_argument('--epochs', type=int, default=1000,
1301
+ help='number of epochs to train (default: 100)')
1302
+ parser.add_argument('--weight_decay', type=float, default=0.00001,
1303
+ help='weight decay')
1304
+ parser.add_argument('--early_stop', type=int, default=10,
1305
+ help='early stop (default: 10)')
1306
+ parser.add_argument('--num_workers', type=int, default=0,
1307
+ help='number of workers (default: 0)')
1308
+ parser.add_argument('--dataset_root', type=str, default="dataset",
1309
+ help='dataset root')
1310
+ args = parser.parse_args()
1311
+
1312
+ return args
1313
+
1314
+ def calc_dragon_type_desc(mol):
1315
+ compound_mol = mol
1316
+ compound_MolWt = Descriptors.ExactMolWt(compound_mol)
1317
+ compound_TPSA = Chem.rdMolDescriptors.CalcTPSA(compound_mol)
1318
+ compound_nRotB = Descriptors.NumRotatableBonds(compound_mol) # Number of rotable bonds
1319
+ compound_HBD = Descriptors.NumHDonors(compound_mol) # Number of H bond donors
1320
+ compound_HBA = Descriptors.NumHAcceptors(compound_mol) # Number of H bond acceptors
1321
+ compound_LogP = Descriptors.MolLogP(compound_mol) # LogP
1322
+ return rdMolDescriptors.CalcAUTOCORR3D(mol) + rdMolDescriptors.CalcMORSE(mol) + \
1323
+ rdMolDescriptors.CalcRDF(mol) + rdMolDescriptors.CalcWHIM(mol) + \
1324
+ [compound_MolWt, compound_TPSA, compound_nRotB, compound_HBD, compound_HBA, compound_LogP]
1325
+
1326
+
1327
+ def eval(model, device, loader_atom_bond,loader_bond_angle):
1328
+ model.eval()
1329
+ y_true = []
1330
+ y_pred = []
1331
+ y_pred_10=[]
1332
+ y_pred_90=[]
1333
+
1334
+ with torch.no_grad():
1335
+ for _, batch in enumerate(zip(loader_atom_bond,loader_bond_angle)):
1336
+ batch_atom_bond = batch[0]
1337
+ batch_bond_angle = batch[1]
1338
+ batch_atom_bond = batch_atom_bond.to(device)
1339
+ batch_bond_angle = batch_bond_angle.to(device)
1340
+ pred = model(batch_atom_bond,batch_bond_angle)[0]
1341
+
1342
+ y_true.append(batch_atom_bond.y.detach().cpu().reshape(-1))
1343
+ y_pred.append(pred[:,1].detach().cpu())
1344
+ y_pred_10.append(pred[:,0].detach().cpu())
1345
+ y_pred_90.append(pred[:,2].detach().cpu())
1346
+ y_true = torch.cat(y_true, dim=0)
1347
+ y_pred = torch.cat(y_pred, dim=0)
1348
+ y_pred_10 = torch.cat(y_pred_10, dim=0)
1349
+ y_pred_90 = torch.cat(y_pred_90, dim=0)
1350
+ # plt.plot(y_pred.cpu().data.numpy(),c='blue')
1351
+ # plt.plot(y_pred_10.cpu().data.numpy(),c='yellow')
1352
+ # plt.plot(y_pred_90.cpu().data.numpy(),c='black')
1353
+ # plt.plot(y_true.cpu().data.numpy(),c='red')
1354
+ #plt.show()
1355
+ input_dict = {"y_true": y_true, "y_pred": y_pred}
1356
+ return torch.mean((y_true - y_pred) ** 2).data.numpy()
1357
+
1358
+
1359
+ def cal_prob(prediction):
1360
+ '''
1361
+ calculate the separation probability Sp
1362
+ '''
1363
+ #input prediction=[pred_1,pred_2]
1364
+ #output: Sp
1365
+ a=prediction[0][0]
1366
+ b=prediction[1][0]
1367
+ if a[2]<b[0]:
1368
+ return 1
1369
+ elif a[0]>b[2]:
1370
+ return 1
1371
+ else:
1372
+ length=min(a[2],b[2])-max(a[0],b[0])
1373
+ all=max(a[2],b[2])-min(a[0],b[0])
1374
+ return 1-length/(all)
1375
+
1376
+
1377
+
1378
+ args = parse_args()
1379
+ nn_params = {
1380
+ 'num_tasks': 3,
1381
+ 'num_layers': args.num_layers,
1382
+ 'emb_dim': args.emb_dim,
1383
+ 'drop_ratio': args.drop_ratio,
1384
+ 'graph_pooling': args.graph_pooling,
1385
+ 'descriptor_dim': 1827
1386
+ }
1387
+ device = args.device
1388
+ model = GINGraphPooling(**nn_params).to(device)
1389
+
1390
+
1391
+ '''
1392
+ Given two compounds and predict the RT in different condition
1393
+ '''
1394
+
1395
+
1396
+ def predict_separate(smile_1, smile_2, input_eluent, input_speed, input_column):
1397
+ speed = []
1398
+ eluent = []
1399
+ smiles=[smile_1,smile_2]
1400
+ for i in range(2):
1401
+ speed.append(input_speed)
1402
+ eluent.append(input_eluent)
1403
+
1404
+ column_descriptor = np.load('column_descriptor.npy', allow_pickle=True)
1405
+ predict_column=input_column
1406
+ col_specify = column_specify[predict_column]
1407
+ col_des = np.array(column_descriptor[col_specify[3]])
1408
+ mols = []
1409
+ y_pred = []
1410
+ all_descriptor = []
1411
+ dataset = []
1412
+ for smile in smiles:
1413
+ mol = Chem.MolFromSmiles(smile)
1414
+ mols.append(mol)
1415
+ for smile in smiles:
1416
+ mol = obtain_3D_mol(smile, 'conform')
1417
+ mol = Chem.MolFromMolFile(f"conform.mol")
1418
+ all_descriptor.append(mord(mol))
1419
+ dataset.append(mol_to_geognn_graph_data_MMFF3d(mol))
1420
+
1421
+ for i in range(0, len(dataset)):
1422
+ data = dataset[i]
1423
+ atom_feature = []
1424
+ bond_feature = []
1425
+ for name in atom_id_names:
1426
+ atom_feature.append(data[name])
1427
+ for name in bond_id_names[0:3]:
1428
+ bond_feature.append(data[name])
1429
+ atom_feature = torch.from_numpy(np.array(atom_feature).T).to(torch.int64)
1430
+ bond_feature = torch.from_numpy(np.array(bond_feature).T).to(torch.int64)
1431
+ bond_float_feature = torch.from_numpy(data['bond_length'].astype(np.float32))
1432
+ bond_angle_feature = torch.from_numpy(data['bond_angle'].astype(np.float32))
1433
+ y = torch.Tensor([float(speed[i])])
1434
+ edge_index = torch.from_numpy(data['edges'].T).to(torch.int64)
1435
+ bond_index = torch.from_numpy(data['BondAngleGraph_edges'].T).to(torch.int64)
1436
+
1437
+ prop = torch.ones([bond_feature.shape[0]]) * eluent[i]
1438
+ coated = torch.ones([bond_feature.shape[0]]) * col_specify[0]
1439
+ diameter = torch.ones([bond_feature.shape[0]]) * col_specify[1]
1440
+ immobilized = torch.ones([bond_feature.shape[0]]) * col_specify[2]
1441
+
1442
+ TPSA = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i][820] / 100
1443
+ RASA = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i][821]
1444
+ RPSA = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i][822]
1445
+ MDEC = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i][1568]
1446
+ MATS = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i][457]
1447
+
1448
+ col_TPSA = torch.ones([bond_angle_feature.shape[0]]) * col_des[820] / 100
1449
+ col_RASA = torch.ones([bond_angle_feature.shape[0]]) * col_des[821]
1450
+ col_RPSA = torch.ones([bond_angle_feature.shape[0]]) * col_des[822]
1451
+ col_MDEC = torch.ones([bond_angle_feature.shape[0]]) * col_des[1568]
1452
+ col_MATS = torch.ones([bond_angle_feature.shape[0]]) * col_des[457]
1453
+
1454
+ bond_feature = torch.cat([bond_feature, coated.reshape(-1, 1)], dim=1)
1455
+ bond_feature = torch.cat([bond_feature, immobilized.reshape(-1, 1)], dim=1)
1456
+ bond_feature = torch.cat([bond_feature, bond_float_feature.reshape(-1, 1)], dim=1)
1457
+ bond_feature = torch.cat([bond_feature, prop.reshape(-1, 1)], dim=1)
1458
+ bond_feature = torch.cat([bond_feature, diameter.reshape(-1, 1)], dim=1)
1459
+
1460
+ bond_angle_feature = torch.cat([bond_angle_feature.reshape(-1, 1), TPSA.reshape(-1, 1)], dim=1)
1461
+ bond_angle_feature = torch.cat([bond_angle_feature, RASA.reshape(-1, 1)], dim=1)
1462
+ bond_angle_feature = torch.cat([bond_angle_feature, RPSA.reshape(-1, 1)], dim=1)
1463
+ bond_angle_feature = torch.cat([bond_angle_feature, MDEC.reshape(-1, 1)], dim=1)
1464
+ bond_angle_feature = torch.cat([bond_angle_feature, MATS.reshape(-1, 1)], dim=1)
1465
+ bond_angle_feature = torch.cat([bond_angle_feature, col_TPSA.reshape(-1, 1)], dim=1)
1466
+ bond_angle_feature = torch.cat([bond_angle_feature, col_RASA.reshape(-1, 1)], dim=1)
1467
+ bond_angle_feature = torch.cat([bond_angle_feature, col_RPSA.reshape(-1, 1)], dim=1)
1468
+ bond_angle_feature = torch.cat([bond_angle_feature, col_MDEC.reshape(-1, 1)], dim=1)
1469
+ bond_angle_feature = torch.cat([bond_angle_feature, col_MATS.reshape(-1, 1)], dim=1)
1470
+
1471
+ data_atom_bond = Data(atom_feature, edge_index, bond_feature, y)
1472
+ data_bond_angle = Data(edge_index=bond_index, edge_attr=bond_angle_feature)
1473
+ model.load_state_dict(
1474
+ torch.load(f'GeoGNN_model.pth',map_location=torch.device('cpu')))
1475
+ model.eval()
1476
+
1477
+ pred, h_graph = model(data_atom_bond.to(device), data_bond_angle.to(device))
1478
+
1479
+ y_pred.append(pred.detach().cpu().data.numpy() / speed[i])
1480
+ if input_speed==0:
1481
+ out_put='Speed cannot be 0!'
1482
+ else:
1483
+ Sp=cal_prob(y_pred)
1484
+ output_1=f'For smile_1,\n the predicted value is: {str(np.round(y_pred[0][0][1],3))}\n'
1485
+ output_2 = f'For smile_2,\n the predicted value is: {str(np.round(y_pred[1][0][1],3))}\n'
1486
+ output_3=f'The separation probability is: {str(np.round(Sp,3))}'
1487
+ out_put=output_1+output_2+output_3
1488
+ return out_put
1489
+
1490
+
1491
+ def column_recommendation(smile_1, smile_2, input_eluent, input_speed):
1492
+ speed = []
1493
+ eluent = []
1494
+ Prediction = []
1495
+ Sp = []
1496
+ smiles = [smile_1, smile_2]
1497
+ for i in range(2):
1498
+ speed.append(input_speed)
1499
+ eluent.append(input_eluent)
1500
+ for predict_column in column_specify.keys():
1501
+ column_descriptor = np.load('column_descriptor.npy', allow_pickle=True)
1502
+ col_specify = column_specify[predict_column]
1503
+ col_des = np.array(column_descriptor[col_specify[3]])
1504
+ mols = []
1505
+ y_pred = []
1506
+ all_descriptor = []
1507
+ dataset = []
1508
+ for smile in smiles:
1509
+ mol = Chem.MolFromSmiles(smile)
1510
+ mols.append(mol)
1511
+ for smile in smiles:
1512
+ mol = obtain_3D_mol(smile, 'conform')
1513
+ mol = Chem.MolFromMolFile(f"conform.mol")
1514
+ all_descriptor.append(mord(mol))
1515
+ dataset.append(mol_to_geognn_graph_data_MMFF3d(mol))
1516
+
1517
+ for i in range(0, len(dataset)):
1518
+ data = dataset[i]
1519
+ atom_feature = []
1520
+ bond_feature = []
1521
+ for name in atom_id_names:
1522
+ atom_feature.append(data[name])
1523
+ for name in bond_id_names[0:3]:
1524
+ bond_feature.append(data[name])
1525
+ atom_feature = torch.from_numpy(np.array(atom_feature).T).to(torch.int64)
1526
+ bond_feature = torch.from_numpy(np.array(bond_feature).T).to(torch.int64)
1527
+ bond_float_feature = torch.from_numpy(data['bond_length'].astype(np.float32))
1528
+ bond_angle_feature = torch.from_numpy(data['bond_angle'].astype(np.float32))
1529
+ y = torch.Tensor([float(speed[i])])
1530
+ edge_index = torch.from_numpy(data['edges'].T).to(torch.int64)
1531
+ bond_index = torch.from_numpy(data['BondAngleGraph_edges'].T).to(torch.int64)
1532
+
1533
+ prop = torch.ones([bond_feature.shape[0]]) * eluent[i]
1534
+ coated = torch.ones([bond_feature.shape[0]]) * col_specify[0]
1535
+ diameter = torch.ones([bond_feature.shape[0]]) * col_specify[1]
1536
+ immobilized = torch.ones([bond_feature.shape[0]]) * col_specify[2]
1537
+
1538
+ TPSA = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i][820] / 100
1539
+ RASA = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i][821]
1540
+ RPSA = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i][822]
1541
+ MDEC = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i][1568]
1542
+ MATS = torch.ones([bond_angle_feature.shape[0]]) * all_descriptor[i][457]
1543
+
1544
+ col_TPSA = torch.ones([bond_angle_feature.shape[0]]) * col_des[820] / 100
1545
+ col_RASA = torch.ones([bond_angle_feature.shape[0]]) * col_des[821]
1546
+ col_RPSA = torch.ones([bond_angle_feature.shape[0]]) * col_des[822]
1547
+ col_MDEC = torch.ones([bond_angle_feature.shape[0]]) * col_des[1568]
1548
+ col_MATS = torch.ones([bond_angle_feature.shape[0]]) * col_des[457]
1549
+
1550
+ bond_feature = torch.cat([bond_feature, coated.reshape(-1, 1)], dim=1)
1551
+ bond_feature = torch.cat([bond_feature, immobilized.reshape(-1, 1)], dim=1)
1552
+ bond_feature = torch.cat([bond_feature, bond_float_feature.reshape(-1, 1)], dim=1)
1553
+ bond_feature = torch.cat([bond_feature, prop.reshape(-1, 1)], dim=1)
1554
+ bond_feature = torch.cat([bond_feature, diameter.reshape(-1, 1)], dim=1)
1555
+
1556
+ bond_angle_feature = torch.cat([bond_angle_feature.reshape(-1, 1), TPSA.reshape(-1, 1)], dim=1)
1557
+ bond_angle_feature = torch.cat([bond_angle_feature, RASA.reshape(-1, 1)], dim=1)
1558
+ bond_angle_feature = torch.cat([bond_angle_feature, RPSA.reshape(-1, 1)], dim=1)
1559
+ bond_angle_feature = torch.cat([bond_angle_feature, MDEC.reshape(-1, 1)], dim=1)
1560
+ bond_angle_feature = torch.cat([bond_angle_feature, MATS.reshape(-1, 1)], dim=1)
1561
+ bond_angle_feature = torch.cat([bond_angle_feature, col_TPSA.reshape(-1, 1)], dim=1)
1562
+ bond_angle_feature = torch.cat([bond_angle_feature, col_RASA.reshape(-1, 1)], dim=1)
1563
+ bond_angle_feature = torch.cat([bond_angle_feature, col_RPSA.reshape(-1, 1)], dim=1)
1564
+ bond_angle_feature = torch.cat([bond_angle_feature, col_MDEC.reshape(-1, 1)], dim=1)
1565
+ bond_angle_feature = torch.cat([bond_angle_feature, col_MATS.reshape(-1, 1)], dim=1)
1566
+
1567
+ data_atom_bond = Data(atom_feature, edge_index, bond_feature, y)
1568
+ data_bond_angle = Data(edge_index=bond_index, edge_attr=bond_angle_feature)
1569
+ model.load_state_dict(
1570
+ torch.load(f'GeoGNN_model.pth', map_location=torch.device('cpu')))
1571
+ model.eval()
1572
+
1573
+ pred, h_graph = model(data_atom_bond.to(device), data_bond_angle.to(device))
1574
+
1575
+ y_pred.append(pred.detach().cpu().data.numpy() / speed[i])
1576
+ Prediction.append(y_pred)
1577
+ Sp.append(cal_prob(y_pred))
1578
+ Prediction_1=np.squeeze(np.array(Prediction))[:,0,1]
1579
+ Prediction_2 = np.squeeze(np.array(Prediction))[:, 1, 1]
1580
+ Sp=np.array(Sp)
1581
+ result=pd.DataFrame({'Column_name':column_specify.keys(),'RT_1':Prediction_1,'RT_2':Prediction_2,'Separation_probability':Sp})
1582
+ result= result[result.loc[:]!=0].dropna()
1583
+ result['RT_1'] = result['RT_1'].apply(lambda x: format(x, '.2f'))
1584
+ result['RT_2'] = result['RT_2'].apply(lambda x: format(x, '.2f'))
1585
+ result = result.sort_values(by="Separation_probability", ascending=False)
1586
+ result['Separation_probability'] = result['Separation_probability'].apply(lambda x: format(x, '.2%'))
1587
+
1588
+ return result
1589
+
1590
+
1591
+
1592
+ if __name__=='__main__':
1593
+ column_recommendation('CC','CCCC',0.1,0.1)
1594
+ demo_1=gr.Interface(fn=predict_separate, inputs=["text", "text", "number", "number",
1595
+ gr.Dropdown(['ADH', 'ODH', 'IC', 'IA', 'OJH', 'ASH', 'IC3',
1596
+ 'IE', 'ID', 'OD3', 'IB', 'AD', 'AD3', 'IF', 'OD',
1597
+ 'AS', 'OJ3', 'IG', 'AZ', 'IAH', 'OJ',
1598
+ 'ICH', 'OZ3', 'IF3', 'IAU'], label="Column type",
1599
+ info="Choose a HPLC column")], outputs=['text'])
1600
+ demo_2=gr.Interface(fn=column_recommendation, inputs=["text", "text", "number", "number"],
1601
+ outputs=['dataframe'])
1602
+ demo=gr.TabbedInterface([demo_1, demo_2], ["Single prediction", "Column recommendation"])
1603
+ demo.launch()
1604
+
1605
+
1606
+
1607
+
1608
+
1609
+
1610
+
column_descriptor.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b263dd8713acc0b863b76ba295f5d5828b350323f4dc304115d196b9cc1fa969
3
+ size 365328