|
import torch |
|
from torch_geometric.nn import MessagePassing |
|
from rdkit.Chem import Descriptors |
|
from torch_geometric.data import Data |
|
import argparse |
|
import warnings |
|
from rdkit.Chem.Descriptors import rdMolDescriptors |
|
import pandas as pd |
|
import os |
|
from mordred import Calculator, descriptors, is_missing |
|
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import numpy as np |
|
from rdkit import Chem |
|
from rdkit.Chem import AllChem |
|
from rdkit.Chem import rdchem |
|
import gradio as gr |
|
# Functional-group SMARTS patterns. CompoundKit compiles every entry with
# Chem.MolFromSmarts and counts the unique substructure matches of each one per
# molecule, so the ORDER and LENGTH of this list define the layout of the
# 'daylight_fg_counts' feature vector. Do not reorder, insert or delete entries
# without retraining anything that consumes that vector.
DAY_LIGHT_FG_SMARTS_LIST = [
    # Carbon (hydrocarbon) patterns: sp3 carbon, allenic, vinylic, acetylenic.
    "[CX4]",
    "[$([CX2](=C)=C)]",
    "[$([CX3]=[CX3])]",
    "[$([CX2]#C)]",
    # Carbonyl and related C/O(/N) patterns: carbonyls, acyl halides, aldehydes,
    # anhydrides, amides, carbamates, carbonates, carboxylic acids/esters, ketones, ethers.
    "[CX3]=[OX1]",
    "[$([CX3]=[OX1]),$([CX3+]-[OX1-])]",
    "[CX3](=[OX1])C",
    "[OX1]=CN",
    "[CX3](=[OX1])O",
    "[CX3](=[OX1])[F,Cl,Br,I]",
    "[CX3H1](=O)[#6]",
    "[CX3](=[OX1])[OX2][CX3](=[OX1])",
    "[NX3][CX3](=[OX1])[#6]",
    "[NX3][CX3]=[NX3+]",
    "[NX3,NX4+][CX3](=[OX1])[OX2,OX1-]",
    "[NX3][CX3](=[OX1])[OX2H0]",
    "[NX3,NX4+][CX3](=[OX1])[OX2H,OX1-]",
    "[CX3](=O)[O-]",
    "[CX3](=[OX1])(O)O",
    "[CX3](=[OX1])([OX2])[OX2H,OX1H0-1]",
    "C[OX2][CX3](=[OX1])[OX2]C",
    "[CX3](=O)[OX2H1]",
    "[CX3](=O)[OX1H0-,OX2H1]",
    "[NX3][CX2]#[NX1]",
    "[#6][CX3](=O)[OX2H0][#6]",
    "[#6][CX3](=O)[#6]",
    "[OD2]([#6])[#6]",
    # Hydrogen patterns.
    "[H]",
    "[!#1]",
    "[H+]",
    "[+H]",
    "[!H]",
    # Nitrogen patterns: amines, enamines, amino-acid / residue side chains, azides,
    # azo, diazo, hydrazines, imines, imides, nitrates, nitriles, isonitriles, nitro,
    # nitroso and N-oxides.
    "[NX3;H2,H1;!$(NC=O)]",
    "[NX3][CX3]=[CX3]",
    "[NX3;H2;!$(NC=[!#6]);!$(NC#[!#6])][#6]",
    "[NX3;H2,H1;!$(NC=O)].[NX3;H2,H1;!$(NC=O)]",
    "[NX3][$(C=C),$(cc)]",
    "[NX3,NX4+][CX4H]([*])[CX3](=[OX1])[O,N]",
    "[NX3H2,NH3X4+][CX4H]([*])[CX3](=[OX1])[NX3,NX4+][CX4H]([*])[CX3](=[OX1])[OX2H,OX1-]",
    "[$([NX3H2,NX4H3+]),$([NX3H](C)(C))][CX4H]([*])[CX3](=[OX1])[OX2H,OX1-,N]",
    "[CH3X4]",
    "[CH2X4][CH2X4][CH2X4][NHX3][CH0X3](=[NH2X3+,NHX2+0])[NH2X3]",
    "[CH2X4][CX3](=[OX1])[NX3H2]",
    "[CH2X4][CX3](=[OX1])[OH0-,OH]",
    "[CH2X4][SX2H,SX1H0-]",
    "[CH2X4][CH2X4][CX3](=[OX1])[OH0-,OH]",
    "[$([$([NX3H2,NX4H3+]),$([NX3H](C)(C))][CX4H2][CX3](=[OX1])[OX2H,OX1-,N])]",
    # The next literal continues onto a second source line via backslash-newline;
    # the continuation must stay at column 0 so no whitespace enters the SMARTS.
    "[CH2X4][#6X3]1:[$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]:\
[$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]1",
    "[CHX4]([CH3X4])[CH2X4][CH3X4]",
    "[CH2X4][CHX4]([CH3X4])[CH3X4]",
    "[CH2X4][CH2X4][CH2X4][CH2X4][NX4+,NX3+0]",
    "[CH2X4][CH2X4][SX2][CH3X4]",
    "[CH2X4][cX3]1[cX3H][cX3H][cX3H][cX3H][cX3H]1",
    "[$([NX3H,NX4H2+]),$([NX3](C)(C)(C))]1[CX4H]([CH2][CH2][CH2]1)[CX3](=[OX1])[OX2H,OX1-,N]",
    "[CH2X4][OX2H]",
    "[NX3][CX3]=[SX1]",
    "[CHX4]([CH3X4])[OX2H]",
    "[CH2X4][cX3]1[cX3H][nX3H][cX3]2[cX3H][cX3H][cX3H][cX3H][cX3]12",
    "[CH2X4][cX3]1[cX3H][cX3H][cX3]([OHX2,OH0X1-])[cX3H][cX3H]1",
    "[CHX4]([CH3X4])[CH3X4]",
    "N[CX4H2][CX3](=[OX1])[O,N]",
    "N1[CX4H]([CH2][CH2][CH2]1)[CX3](=[OX1])[O,N]",
    "[$(*-[NX2-]-[NX2+]#[NX1]),$(*-[NX2]=[NX2+]=[NX1-])]",
    "[$([NX1-]=[NX2+]=[NX1-]),$([NX1]#[NX2+]-[NX1-2])]",
    "[#7]",
    "[NX2]=N",
    "[NX2]=[NX2]",
    "[$([NX2]=[NX3+]([O-])[#6]),$([NX2]=[NX3+0](=[O])[#6])]",
    "[$([#6]=[N+]=[N-]),$([#6-]-[N+]#[N])]",
    "[$([nr5]:[nr5,or5,sr5]),$([nr5]:[cr5]:[nr5,or5,sr5])]",
    "[NX3][NX3]",
    "[NX3][NX2]=[*]",
    "[CX3;$([C]([#6])[#6]),$([CH][#6])]=[NX2][#6]",
    "[$([CX3]([#6])[#6]),$([CX3H][#6])]=[$([NX2][#6]),$([NX2H])]",
    "[NX3+]=[CX3]",
    "[CX3](=[OX1])[NX3H][CX3](=[OX1])",
    "[CX3](=[OX1])[NX3H0]([#6])[CX3](=[OX1])",
    "[CX3](=[OX1])[NX3H0]([NX3H0]([CX3](=[OX1]))[CX3](=[OX1]))[CX3](=[OX1])",
    "[$([NX3](=[OX1])(=[OX1])O),$([NX3+]([OX1-])(=[OX1])O)]",
    "[$([OX1]=[NX3](=[OX1])[OX1-]),$([OX1]=[NX3+]([OX1-])[OX1-])]",
    "[NX1]#[CX2]",
    "[CX1-]#[NX2+]",
    "[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]",
    "[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8].[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]",
    "[NX2]=[OX1]",
    "[$([#7+][OX1-]),$([#7v5]=[OX1]);!$([#7](~[O])~[O]);!$([#7]=[#7])]",
    # Oxygen patterns: hydroxyls, enols, phenols, peroxides.
    "[OX2H]",
    "[#6][OX2H]",
    "[OX2H][CX3]=[OX1]",
    "[OX2H]P",
    "[OX2H][#6X3]=[#6]",
    "[OX2H][cX3]:[c]",
    "[OX2H][$(C=C),$(cc)]",
    "[$([OH]-*=[!#6])]",
    "[OX2,OX1-][OX2,OX1-]",
    # Phosphorus patterns (phosphoric acids / esters). Both literals span several
    # source lines via backslash-newline; continuations must stay at column 0.
    "[$(P(=[OX1])([$([OX2H]),$([OX1-]),$([OX2]P)])([$([OX2H]),$([OX1-]),\
$([OX2]P)])[$([OX2H]),$([OX1-]),$([OX2]P)]),$([P+]([OX1-])([$([OX2H]),$([OX1-])\
,$([OX2]P)])([$([OX2H]),$([OX1-]),$([OX2]P)])[$([OX2H]),$([OX1-]),$([OX2]P)])]",
    "[$(P(=[OX1])([OX2][#6])([$([OX2H]),$([OX1-]),$([OX2][#6])])[$([OX2H]),\
$([OX1-]),$([OX2][#6]),$([OX2]P)]),$([P+]([OX1-])([OX2][#6])([$([OX2H]),$([OX1-]),\
$([OX2][#6])])[$([OX2H]),$([OX1-]),$([OX2][#6]),$([OX2]P)])]",
    # Sulfur patterns: thio-carbonyls, thiols, sulfides, disulfides, sulfinates,
    # sulfones, sulfonic acids/esters, sulfonamides, sulfoxides, sulfates.
    "[S-][CX3](=S)[#6]",
    "[#6X3](=[SX1])([!N])[!N]",
    "[SX2]",
    "[#16X2H]",
    "[#16!H0]",
    "[#16X2H0]",
    "[#16X2H0][!#16]",
    "[#16X2H0][#16X2H0]",
    "[#16X2H0][!#16].[#16X2H0][!#16]",
    "[$([#16X3](=[OX1])[OX2H0]),$([#16X3+]([OX1-])[OX2H0])]",
    "[$([#16X3](=[OX1])[OX2H,OX1H0-]),$([#16X3+]([OX1-])[OX2H,OX1H0-])]",
    "[$([#16X4](=[OX1])=[OX1]),$([#16X4+2]([OX1-])[OX1-])]",
    "[$([#16X4](=[OX1])(=[OX1])([#6])[#6]),$([#16X4+2]([OX1-])([OX1-])([#6])[#6])]",
    "[$([#16X4](=[OX1])(=[OX1])([#6])[OX2H,OX1H0-]),$([#16X4+2]([OX1-])([OX1-])([#6])[OX2H,OX1H0-])]",
    "[$([#16X4](=[OX1])(=[OX1])([#6])[OX2H0]),$([#16X4+2]([OX1-])([OX1-])([#6])[OX2H0])]",
    "[$([#16X4]([NX3])(=[OX1])(=[OX1])[#6]),$([#16X4+2]([NX3])([OX1-])([OX1-])[#6])]",
    "[SX4](C)(C)(=O)=N",
    "[$([SX4](=[OX1])(=[OX1])([!O])[NX3]),$([SX4+2]([OX1-])([OX1-])([!O])[NX3])]",
    "[$([#16X3]=[OX1]),$([#16X3+][OX1-])]",
    "[$([#16X3](=[OX1])([#6])[#6]),$([#16X3+]([OX1-])([#6])[#6])]",
    "[$([#16X4](=[OX1])(=[OX1])([OX2H,OX1H0-])[OX2][#6]),$([#16X4+2]([OX1-])([OX1-])([OX2H,OX1H0-])[OX2][#6])]",
    "[$([SX4](=O)(=O)(O)O),$([SX4+2]([O-])([O-])(O)O)]",
    # NOTE(review): both alternatives of the next pattern are identical; by analogy with
    # its neighbours the second was probably meant to be the charged [#16X4+2]([OX1-])([OX1-])
    # form. Kept as-is because changing it would change the feature values.
    "[$([#16X4](=[OX1])(=[OX1])([OX2][#6])[OX2][#6]),$([#16X4](=[OX1])(=[OX1])([OX2][#6])[OX2][#6])]",
    "[$([#16X4]([NX3])(=[OX1])(=[OX1])[OX2][#6]),$([#16X4+2]([NX3])([OX1-])([OX1-])[OX2][#6])]",
    "[$([#16X4]([NX3])(=[OX1])(=[OX1])[OX2H,OX1H0-]),$([#16X4+2]([NX3])([OX1-])([OX1-])[OX2H,OX1H0-])]",
    "[#16X2][OX2H,OX1H0-]",
    "[#16X2][OX2H0]",
    # Halogen patterns.
    "[#6][F,Cl,Br,I]",
    "[F,Cl,Br,I]",
    "[F,Cl,Br,I].[F,Cl,Br,I].[F,Cl,Br,I]",
]
|
|
|
|
|
def get_gasteiger_partial_charges(mol, n_iter=12):
    """
    Compute Gasteiger partial charges on ``mol`` and return them in atom order.

    RDKit stores each charge on its atom as the ``_GasteigerCharge`` property;
    this function reads those properties back and converts them to float.

    Args:
        mol: rdkit mol object. It is modified in place (charge properties are set).
        n_iter (int): number of Gasteiger iterations. Default 12.

    Returns:
        list of float, one partial charge per atom.
    """
    # throwOnParamFailure=True asks RDKit to raise instead of silently continuing
    # when it has no Gasteiger parameters for an atom.
    Chem.rdPartialCharges.ComputeGasteigerCharges(
        mol, nIter=n_iter, throwOnParamFailure=True)
    charges = []
    for atom in mol.GetAtoms():
        charges.append(float(atom.GetProp('_GasteigerCharge')))
    return charges
|
|
|
|
|
def create_standardized_mol_id(smiles):
    """
    Build a standardized InChI identifier from a SMILES string.

    The SMILES is first canonicalized without stereochemistry
    (``isomericSmiles=False``) and re-parsed. If the result contains several
    species ('.'-separated), only the species with the most atoms is kept.

    Args:
        smiles: smiles sequence.

    Returns:
        InChI string, or None if the SMILES (or its canonical re-parse) is invalid.
    """
    if not check_smiles_validity(smiles):
        return None

    canonical = AllChem.MolToSmiles(AllChem.MolFromSmiles(smiles),
                                    isomericSmiles=False)
    mol = AllChem.MolFromSmiles(canonical)
    if mol is None:
        return None

    if '.' not in canonical:
        return AllChem.MolToInchi(mol)

    # Multi-species input: keep only the largest fragment.
    largest_mol = get_largest_mol(split_rdkit_mol_obj(mol))
    return AllChem.MolToInchi(largest_mol)
|
|
|
|
|
def check_smiles_validity(smiles):
    """
    Check whether ``smiles`` can be converted to an rdkit mol object.

    Args:
        smiles: smiles sequence (any value; non-string input is treated as invalid
            if RDKit raises on it).

    Returns:
        True if ``Chem.MolFromSmiles`` returns a truthy mol, False if it returns
        None/falsy or raises.
    """
    try:
        m = Chem.MolFromSmiles(smiles)
    except Exception:
        # Deliberately broad: any RDKit failure just means "not valid".
        return False
    return bool(m)
|
|
|
|
|
def split_rdkit_mol_obj(mol):
    """
    Split an rdkit mol object into one mol object per '.'-separated species.

    A single-species molecule yields a one-element list. Fragments whose SMILES
    fails ``check_smiles_validity`` are dropped.

    Args:
        mol: rdkit mol object.

    Returns:
        list of rdkit mol objects.
    """
    fragment_smiles = AllChem.MolToSmiles(mol, isomericSmiles=True).split('.')
    return [AllChem.MolFromSmiles(frag)
            for frag in fragment_smiles
            if check_smiles_validity(frag)]
|
|
|
|
|
def get_largest_mol(mol_list):
    """
    Return the mol object in ``mol_list`` that has the most atoms.

    On ties the earliest such mol is returned (``max`` keeps the first maximum).

    Args:
        mol_list(list): a list of rdkit mol objects.

    Returns:
        the largest mol.

    Raises:
        ValueError: if ``mol_list`` is empty.
    """
    return max(mol_list, key=lambda candidate: len(candidate.GetAtoms()))
|
|
|
|
|
def rdchem_enum_to_list(values):
    """
    Turn an index-keyed enum mapping into a list ordered by index.

    Example input (``rdchem.ChiralType.values``):
        {0: CHI_UNSPECIFIED, 1: CHI_TETRAHEDRAL_CW,
         2: CHI_TETRAHEDRAL_CCW, 3: CHI_OTHER}

    Keys are expected to be exactly 0..len(values)-1; a gap raises KeyError.
    """
    ordered = []
    for idx in range(len(values)):
        ordered.append(values[idx])
    return ordered
|
|
|
|
|
def safe_index(alist, elem):
    """
    Return the index of ``elem`` in ``alist``, or the last index if it is absent.

    Vocabularies here usually end with a 'misc' bucket, so unknown values land
    there. For an empty ``alist`` the result is -1.
    """
    try:
        position = alist.index(elem)
    except ValueError:
        position = len(alist) - 1
    return position
|
|
|
|
|
def get_atom_feature_dims(list_acquired_feature_names):
    """
    Return the vocabulary size of each requested atom feature, in the given order.

    Args:
        list_acquired_feature_names: keys of ``CompoundKit.atom_vocab_dict``.

    Returns:
        list of int, ``len(CompoundKit.atom_vocab_dict[name])`` per name.
    """
    vocab = CompoundKit.atom_vocab_dict
    return [len(vocab[name]) for name in list_acquired_feature_names]
|
|
|
|
|
def get_bond_feature_dims(list_acquired_feature_names):
    """
    Return the vocabulary size + 1 of each requested bond feature, in order.

    The extra slot leaves room for id ``len(vocab)``, which
    ``new_mol_to_graph_data`` assigns to self-loop edges.

    Args:
        list_acquired_feature_names: keys of ``CompoundKit.bond_vocab_dict``.

    Returns:
        list of int.
    """
    vocab = CompoundKit.bond_vocab_dict
    return [len(vocab[name]) + 1 for name in list_acquired_feature_names]
|
|
|
|
|
class CompoundKit(object):
    """
    CompoundKit: vocabularies and static helpers that turn RDKit atoms, bonds and
    molecules into categorical feature ids, float features and fingerprints.

    Categorical values are encoded with ``safe_index`` against the vocabularies
    below, so a value missing from a vocabulary maps to that vocabulary's last
    entry ('misc' where the vocabulary has one).
    """

    # Categorical atom-feature vocabularies. Key order matters:
    # new_mol_to_graph_data iterates these keys in insertion order.
    atom_vocab_dict = {
        "atomic_num": list(range(1, 119)) + ['misc'],
        "chiral_tag": rdchem_enum_to_list(rdchem.ChiralType.values),
        "degree": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
        "explicit_valence": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 'misc'],
        "formal_charge": [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
        "hybridization": rdchem_enum_to_list(rdchem.HybridizationType.values),
        "implicit_valence": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 'misc'],
        "is_aromatic": [0, 1],
        "total_numHs": [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
        'num_radical_e': [0, 1, 2, 3, 4, 'misc'],
        'atom_is_in_ring': [0, 1],
        'valence_out_shell': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
        'in_num_ring_with_size3': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
        'in_num_ring_with_size4': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
        'in_num_ring_with_size5': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
        'in_num_ring_with_size6': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
        'in_num_ring_with_size7': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
        'in_num_ring_with_size8': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
    }

    # Categorical bond-feature vocabularies (same encoding convention).
    bond_vocab_dict = {
        "bond_dir": rdchem_enum_to_list(rdchem.BondDir.values),
        "bond_type": rdchem_enum_to_list(rdchem.BondType.values),
        "is_in_ring": [0, 1],
        'bond_stereo': rdchem_enum_to_list(rdchem.BondStereo.values),
        'is_conjugated': [0, 1],
    }

    # Float-valued atom features produced by atom_to_feat_vector.
    atom_float_names = ["van_der_waals_radis", "partial_charge", 'mass']

    # Functional-group SMARTS and their compiled query mols; list order defines the
    # layout of get_daylight_functional_group_counts' output.
    day_light_fg_smarts_list = DAY_LIGHT_FG_SMARTS_LIST
    day_light_fg_mo_list = [Chem.MolFromSmarts(smarts) for smarts in day_light_fg_smarts_list]

    # Fingerprint lengths.
    morgan_fp_N = 200
    morgan2048_fp_N = 2048
    maccs_fp_N = 167

    period_table = Chem.GetPeriodicTable()

    @staticmethod
    def get_atom_value(atom, name):
        """
        Return the raw RDKit value of atom feature ``name``.

        Note: 'degree' here uses GetDegree(), whereas atom_to_feat_vector uses
        GetTotalDegree(). 'mass' is truncated to int.

        Raises:
            ValueError: for names not handled here (including the
                in_num_ring_with_size* features, which come from get_ring_size).
        """
        if name == 'atomic_num':
            return atom.GetAtomicNum()
        elif name == 'chiral_tag':
            return atom.GetChiralTag()
        elif name == 'degree':
            return atom.GetDegree()
        elif name == 'explicit_valence':
            return atom.GetExplicitValence()
        elif name == 'formal_charge':
            return atom.GetFormalCharge()
        elif name == 'hybridization':
            return atom.GetHybridization()
        elif name == 'implicit_valence':
            return atom.GetImplicitValence()
        elif name == 'is_aromatic':
            return int(atom.GetIsAromatic())
        elif name == 'mass':
            return int(atom.GetMass())
        elif name == 'total_numHs':
            return atom.GetTotalNumHs()
        elif name == 'num_radical_e':
            return atom.GetNumRadicalElectrons()
        elif name == 'atom_is_in_ring':
            return int(atom.IsInRing())
        elif name == 'valence_out_shell':
            return CompoundKit.period_table.GetNOuterElecs(atom.GetAtomicNum())
        else:
            raise ValueError(name)

    @staticmethod
    def get_atom_feature_id(atom, name):
        """Return the vocabulary index of atom feature ``name`` (unknown -> last index)."""
        assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name
        return safe_index(CompoundKit.atom_vocab_dict[name], CompoundKit.get_atom_value(atom, name))

    @staticmethod
    def get_atom_feature_size(name):
        """Return the vocabulary size of atom feature ``name``."""
        assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name
        return len(CompoundKit.atom_vocab_dict[name])

    @staticmethod
    def get_bond_value(bond, name):
        """
        Return the raw RDKit value of bond feature ``name``.

        Raises:
            ValueError: for unknown names.
        """
        if name == 'bond_dir':
            return bond.GetBondDir()
        elif name == 'bond_type':
            return bond.GetBondType()
        elif name == 'is_in_ring':
            return int(bond.IsInRing())
        elif name == 'is_conjugated':
            return int(bond.GetIsConjugated())
        elif name == 'bond_stereo':
            return bond.GetStereo()
        else:
            raise ValueError(name)

    @staticmethod
    def get_bond_feature_id(bond, name):
        """Return the vocabulary index of bond feature ``name`` (unknown -> last index)."""
        assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name
        return safe_index(CompoundKit.bond_vocab_dict[name], CompoundKit.get_bond_value(bond, name))

    @staticmethod
    def get_bond_feature_size(name):
        """Return the vocabulary size of bond feature ``name``."""
        assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name
        return len(CompoundKit.bond_vocab_dict[name])

    @staticmethod
    def get_morgan_fingerprint(mol, radius=2):
        """Return a ``morgan_fp_N``-bit Morgan fingerprint as a list of 0/1 ints."""
        nBits = CompoundKit.morgan_fp_N
        mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
        return [int(b) for b in mfp.ToBitString()]

    @staticmethod
    def get_morgan2048_fingerprint(mol, radius=2):
        """Return a ``morgan2048_fp_N``-bit Morgan fingerprint as a list of 0/1 ints."""
        nBits = CompoundKit.morgan2048_fp_N
        mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
        return [int(b) for b in mfp.ToBitString()]

    @staticmethod
    def get_maccs_fingerprint(mol):
        """Return the MACCS keys fingerprint as a list of 0/1 ints."""
        fp = AllChem.GetMACCSKeysFingerprint(mol)
        return [int(b) for b in fp.ToBitString()]

    @staticmethod
    def get_daylight_functional_group_counts(mol):
        """Return, per SMARTS in ``day_light_fg_mo_list``, the number of unique matches in ``mol``."""
        fg_counts = []
        for fg_mol in CompoundKit.day_light_fg_mo_list:
            sub_structs = Chem.Mol.GetSubstructMatches(mol, fg_mol, uniquify=True)
            fg_counts.append(len(sub_structs))
        return fg_counts

    @staticmethod
    def get_ring_size(mol):
        """
        Return an (N, 6) nested list of ring-membership counts.

        Row i holds, for ring sizes 3..8, how many rings from
        ``mol.GetRingInfo().AtomRings()`` of that size contain atom i. Counts
        above 8 are replaced by 9, which maps to 'misc' in the
        in_num_ring_with_size* vocabularies.
        """
        rings = mol.GetRingInfo()
        rings_info = []
        for r in rings.AtomRings():
            rings_info.append(r)
        ring_list = []
        for atom in mol.GetAtoms():
            atom_result = []
            for ringsize in range(3, 9):
                num_of_ring_at_ringsize = 0
                for r in rings_info:
                    if len(r) == ringsize and atom.GetIdx() in r:
                        num_of_ring_at_ringsize += 1
                if num_of_ring_at_ringsize > 8:
                    num_of_ring_at_ringsize = 9
                atom_result.append(num_of_ring_at_ringsize)
            ring_list.append(atom_result)
        return ring_list

    @staticmethod
    def atom_to_feat_vector(atom):
        """
        Return a dict of categorical feature ids plus float features for ``atom``.

        'partial_charge' reads the ``_GasteigerCharge`` property, so Gasteiger
        charges must already have been computed on the owning mol.
        """
        atom_names = {
            "atomic_num": safe_index(CompoundKit.atom_vocab_dict["atomic_num"], atom.GetAtomicNum()),
            "chiral_tag": safe_index(CompoundKit.atom_vocab_dict["chiral_tag"], atom.GetChiralTag()),
            "degree": safe_index(CompoundKit.atom_vocab_dict["degree"], atom.GetTotalDegree()),
            "explicit_valence": safe_index(CompoundKit.atom_vocab_dict["explicit_valence"], atom.GetExplicitValence()),
            "formal_charge": safe_index(CompoundKit.atom_vocab_dict["formal_charge"], atom.GetFormalCharge()),
            "hybridization": safe_index(CompoundKit.atom_vocab_dict["hybridization"], atom.GetHybridization()),
            "implicit_valence": safe_index(CompoundKit.atom_vocab_dict["implicit_valence"], atom.GetImplicitValence()),
            "is_aromatic": safe_index(CompoundKit.atom_vocab_dict["is_aromatic"], int(atom.GetIsAromatic())),
            "total_numHs": safe_index(CompoundKit.atom_vocab_dict["total_numHs"], atom.GetTotalNumHs()),
            'num_radical_e': safe_index(CompoundKit.atom_vocab_dict['num_radical_e'], atom.GetNumRadicalElectrons()),
            'atom_is_in_ring': safe_index(CompoundKit.atom_vocab_dict['atom_is_in_ring'], int(atom.IsInRing())),
            'valence_out_shell': safe_index(CompoundKit.atom_vocab_dict['valence_out_shell'],
                                            CompoundKit.period_table.GetNOuterElecs(atom.GetAtomicNum())),
            'van_der_waals_radis': CompoundKit.period_table.GetRvdw(atom.GetAtomicNum()),
            'partial_charge': CompoundKit.check_partial_charge(atom),
            'mass': atom.GetMass(),
        }
        return atom_names

    @staticmethod
    def get_atom_names(mol):
        """
        Return one feature dict per atom (atom_to_feat_vector + ring-size ids).

        Side effect: computes Gasteiger charges on ``mol`` in place.
        TODO: to be removed in the future.
        """
        Chem.rdPartialCharges.ComputeGasteigerCharges(mol)
        atom_features_dicts = [CompoundKit.atom_to_feat_vector(atom) for atom in mol.GetAtoms()]

        ring_list = CompoundKit.get_ring_size(mol)
        for i in range(len(atom_features_dicts)):
            atom_features_dicts[i]['in_num_ring_with_size3'] = safe_index(
                CompoundKit.atom_vocab_dict['in_num_ring_with_size3'], ring_list[i][0])
            atom_features_dicts[i]['in_num_ring_with_size4'] = safe_index(
                CompoundKit.atom_vocab_dict['in_num_ring_with_size4'], ring_list[i][1])
            atom_features_dicts[i]['in_num_ring_with_size5'] = safe_index(
                CompoundKit.atom_vocab_dict['in_num_ring_with_size5'], ring_list[i][2])
            atom_features_dicts[i]['in_num_ring_with_size6'] = safe_index(
                CompoundKit.atom_vocab_dict['in_num_ring_with_size6'], ring_list[i][3])
            atom_features_dicts[i]['in_num_ring_with_size7'] = safe_index(
                CompoundKit.atom_vocab_dict['in_num_ring_with_size7'], ring_list[i][4])
            atom_features_dicts[i]['in_num_ring_with_size8'] = safe_index(
                CompoundKit.atom_vocab_dict['in_num_ring_with_size8'], ring_list[i][5])

        return atom_features_dicts

    @staticmethod
    def check_partial_charge(atom):
        """
        Return the atom's Gasteiger charge with non-finite values replaced.

        NaN -> 0, +inf -> 10, -inf -> -10; finite values are returned unchanged.
        """
        pc = atom.GetDoubleProp('_GasteigerCharge')
        if pc != pc:
            # NaN is the only float that is not equal to itself.
            pc = 0
        elif pc == float('inf'):
            pc = 10
        elif pc == float('-inf'):
            # Symmetric with the +inf clamp; previously -inf leaked through.
            pc = -10
        return pc
|
|
|
|
|
class Compound3DKit(object):
    """Helpers that read atom coordinates and derive bond lengths / bond angles."""

    @staticmethod
    def get_atom_poses(mol, conf):
        """
        Return a list of [x, y, z] coordinates, one per atom, read from ``conf``.

        If any atom has atomic number 0, every position is reported as
        [0.0, 0.0, 0.0] instead.
        """
        atom_poses = []
        for i, atom in enumerate(mol.GetAtoms()):
            if atom.GetAtomicNum() == 0:
                # Build independent rows; `[[0, 0, 0]] * n` would alias one list n times.
                return [[0.0, 0.0, 0.0] for _ in range(len(mol.GetAtoms()))]
            pos = conf.GetAtomPosition(i)
            atom_poses.append([pos.x, pos.y, pos.z])
        return atom_poses

    @staticmethod
    def get_MMFF_atom_poses(mol, numConfs=None, return_energy=False):
        """
        Return ``(mol, atom_poses)`` from the molecule's existing default conformer.

        Despite the name, no MMFF optimization happens here. The mol is returned
        unchanged. ``numConfs`` and ``return_energy`` are accepted for interface
        compatibility but ignored. Whatever ``mol.GetConformer()`` raises for a
        molecule without conformers propagates.
        """
        conf = mol.GetConformer()
        atom_poses = Compound3DKit.get_atom_poses(mol, conf)
        return mol, atom_poses

    @staticmethod
    def get_2d_atom_poses(mol):
        """Compute 2D coordinates on ``mol`` (in place) and return its atom positions."""
        AllChem.Compute2DCoords(mol)
        conf = mol.GetConformer()
        atom_poses = Compound3DKit.get_atom_poses(mol, conf)
        return atom_poses

    @staticmethod
    def get_bond_lengths(edges, atom_poses):
        """
        Return a float32 array with the Euclidean length of each (src, tar) edge.

        ``atom_poses`` must support ``atom_poses[i] - atom_poses[j]`` (e.g. a numpy array).
        """
        bond_lengths = []
        for src_node_i, tar_node_j in edges:
            bond_lengths.append(np.linalg.norm(atom_poses[tar_node_j] - atom_poses[src_node_i]))
        bond_lengths = np.array(bond_lengths, 'float32')
        return bond_lengths

    @staticmethod
    def get_superedge_angles(edges, atom_poses, dir_type='HT'):
        """
        Build the bond-angle graph over directed edges.

        For every target edge, candidate source edges are those whose end atom
        (``edges[:, 1]``) equals the target's start atom (``dir_type='HT'``) or the
        target's end atom (``dir_type='HH'``); the target itself is skipped.

        Args:
            edges: numpy int array of shape (E, 2) (column indexing is used).
            atom_poses: numpy array of per-atom coordinates.
            dir_type: 'HT' or 'HH'.

        Returns:
            super_edges: int64 array (K, 2) of [src_edge_idx, tar_edge_idx].
            bond_angles: float32 array (K,), radians between the two edge vectors
                (0 when either vector has zero length).
            bond_angle_dirs: list of bool, ``src_edge[1] == tar_edge[0]`` per pair.

        Raises:
            ValueError: for an unknown ``dir_type``.
        """

        def _get_vec(atom_poses, edge):
            return atom_poses[edge[1]] - atom_poses[edge[0]]

        def _get_angle(vec1, vec2):
            norm1 = np.linalg.norm(vec1)
            norm2 = np.linalg.norm(vec2)
            if norm1 == 0 or norm2 == 0:
                return 0
            # The +1e-5 keeps both unit vectors strictly inside the unit ball, so
            # their dot product stays within arccos' domain.
            vec1 = vec1 / (norm1 + 1e-5)
            vec2 = vec2 / (norm2 + 1e-5)
            angle = np.arccos(np.dot(vec1, vec2))
            return angle

        E = len(edges)
        edge_indices = np.arange(E)
        super_edges = []
        bond_angles = []
        bond_angle_dirs = []
        for tar_edge_i in range(E):
            tar_edge = edges[tar_edge_i]
            if dir_type == 'HT':
                src_edge_indices = edge_indices[edges[:, 1] == tar_edge[0]]
            elif dir_type == 'HH':
                src_edge_indices = edge_indices[edges[:, 1] == tar_edge[1]]
            else:
                raise ValueError(dir_type)
            for src_edge_i in src_edge_indices:
                if src_edge_i == tar_edge_i:
                    continue
                src_edge = edges[src_edge_i]
                src_vec = _get_vec(atom_poses, src_edge)
                tar_vec = _get_vec(atom_poses, tar_edge)
                super_edges.append([src_edge_i, tar_edge_i])
                angle = _get_angle(src_vec, tar_vec)
                bond_angles.append(angle)
                bond_angle_dirs.append(src_edge[1] == tar_edge[0])

        if len(super_edges) == 0:
            super_edges = np.zeros([0, 2], 'int64')
            bond_angles = np.zeros([0, ], 'float32')
        else:
            super_edges = np.array(super_edges, 'int64')
            bond_angles = np.array(bond_angles, 'float32')
        return super_edges, bond_angles, bond_angle_dirs
|
|
|
|
|
def new_smiles_to_graph_data(smiles, **kwargs):
    """
    Parse ``smiles`` and convert it with ``new_mol_to_graph_data``.

    Extra keyword arguments are accepted but ignored.

    Returns:
        the graph-data dict, or None if the SMILES cannot be parsed (or the
        molecule has no atoms).
    """
    parsed = AllChem.MolFromSmiles(smiles)
    return None if parsed is None else new_mol_to_graph_data(parsed)
|
|
|
|
|
def new_mol_to_graph_data(mol):
    """
    Convert an rdkit mol into a dict of numpy arrays using the full CompoundKit feature set.

    Every key of ``CompoundKit.atom_vocab_dict`` becomes an int64 per-atom id array
    (unshifted vocabulary index), each name in ``CompoundKit.atom_float_names`` a
    float32 per-atom array, and every key of ``CompoundKit.bond_vocab_dict`` an int64
    per-edge id array. Each bond appears in both directions. Each atom additionally
    gets a self-loop edge whose bond ids are ``len(vocab)`` (one past the last
    vocabulary index).

    Args:
        mol: rdkit mol object. Gasteiger charges are computed on it in place
            (via ``CompoundKit.get_atom_names``).

    Returns:
        dict with the feature arrays above plus 'edges' (int64, shape (E, 2)),
        'morgan_fp', 'maccs_fp' and 'daylight_fg_counts'; or None if ``mol`` has
        no atoms.
    """
    if len(mol.GetAtoms()) == 0:
        return None

    atom_id_names = list(CompoundKit.atom_vocab_dict.keys()) + CompoundKit.atom_float_names
    bond_id_names = list(CompoundKit.bond_vocab_dict.keys())

    # Atom features: transpose the per-atom dicts into per-feature lists.
    data = {name: [] for name in atom_id_names}
    raw_atom_feat_dicts = CompoundKit.get_atom_names(mol)
    for atom_feat in raw_atom_feat_dicts:
        for name in atom_id_names:
            data[name].append(atom_feat[name])

    # Bond features, stored once per direction.
    for name in bond_id_names:
        data[name] = []
    data['edges'] = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        data['edges'] += [(i, j), (j, i)]
        for name in bond_id_names:
            bond_feature_id = CompoundKit.get_bond_feature_id(bond, name)
            data[name] += [bond_feature_id] * 2

    # Self-loops, tagged with the reserved id len(vocab)
    # (== get_bond_feature_dims([name])[0] - 1).
    N = len(data[atom_id_names[0]])
    for i in range(N):
        data['edges'] += [(i, i)]
    for name in bond_id_names:
        bond_feature_id = CompoundKit.get_bond_feature_size(name)
        data[name] += [bond_feature_id] * N

    for name in list(CompoundKit.atom_vocab_dict.keys()):
        data[name] = np.array(data[name], 'int64')
    for name in CompoundKit.atom_float_names:
        data[name] = np.array(data[name], 'float32')
    for name in bond_id_names:
        data[name] = np.array(data[name], 'int64')
    data['edges'] = np.array(data['edges'], 'int64')

    # Molecule-level fingerprints / functional-group counts.
    data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64')
    data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64')
    data['daylight_fg_counts'] = np.array(CompoundKit.get_daylight_functional_group_counts(mol), 'int64')
    return data
|
|
|
|
|
def mol_to_graph_data(mol):
    """
    Convert an rdkit mol into the dict of numpy arrays used by mol_to_geognn_graph_data.

    Categorical atom/bond ids are the CompoundKit vocabulary index shifted by +1.
    Every bond is stored in both directions, and every atom gets a self-loop
    edge whose bond ids are ``CompoundKit.get_bond_feature_size(name) + 2``.

    Args:
        mol: rdkit mol object.

    Returns:
        dict with int64 arrays for the 9 atom id names and for 'bond_dir',
        'bond_type', 'is_in_ring'; float32 'mass' (int(GetMass()) * 0.01);
        int64 'edges' of shape (E, 2); and 'morgan_fp', 'maccs_fp',
        'daylight_fg_counts'. Returns None if ``mol`` has no atoms or contains an
        atom with atomic number 0.
    """
    if len(mol.GetAtoms()) == 0:
        return None

    atom_id_names = [
        "atomic_num", "chiral_tag", "degree", "explicit_valence",
        "formal_charge", "hybridization", "implicit_valence",
        "is_aromatic", "total_numHs",
    ]
    bond_id_names = [
        "bond_dir", "bond_type", "is_in_ring",
    ]

    data = {}
    for name in atom_id_names:
        data[name] = []
    data['mass'] = []
    for name in bond_id_names:
        data[name] = []
    data['edges'] = []

    # Atom features; molecules with atomic-number-0 atoms are rejected.
    for atom in mol.GetAtoms():
        if atom.GetAtomicNum() == 0:
            return None
        for name in atom_id_names:
            data[name].append(CompoundKit.get_atom_feature_id(atom, name) + 1)
        data['mass'].append(CompoundKit.get_atom_value(atom, 'mass') * 0.01)

    # Bond features, stored once per direction.
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        data['edges'] += [(i, j), (j, i)]
        for name in bond_id_names:
            bond_feature_id = CompoundKit.get_bond_feature_id(bond, name) + 1
            data[name] += [bond_feature_id] * 2

    # Self-loops. Because of the early return above N >= 1, so 'edges' is never
    # empty and np.array always yields shape (E, 2).
    N = len(data[atom_id_names[0]])
    for i in range(N):
        data['edges'] += [(i, i)]
    for name in bond_id_names:
        bond_feature_id = CompoundKit.get_bond_feature_size(name) + 2
        data[name] += [bond_feature_id] * N

    for name in atom_id_names:
        data[name] = np.array(data[name], 'int64')
    data['mass'] = np.array(data['mass'], 'float32')
    for name in bond_id_names:
        data[name] = np.array(data[name], 'int64')
    data['edges'] = np.array(data['edges'], 'int64')

    # Molecule-level fingerprints / functional-group counts.
    data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64')
    data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64')
    data['daylight_fg_counts'] = np.array(CompoundKit.get_daylight_functional_group_counts(mol), 'int64')
    return data
|
|
|
|
|
def mol_to_geognn_graph_data(mol, atom_poses, dir_type):
    """
    Build graph data plus geometry features (bond lengths, bond-angle graph).

    Args:
        mol: rdkit molecule.
        atom_poses: per-atom [x, y, z] coordinates, in atom order.
        dir_type: direction type for the bond-angle graph ('HT' or 'HH'),
            forwarded to ``Compound3DKit.get_superedge_angles``.

    Returns:
        the ``mol_to_graph_data`` dict extended with 'atom_pos', 'bond_length',
        'BondAngleGraph_edges' and 'bond_angle'; or None if ``mol`` has no atoms
        or ``mol_to_graph_data`` rejects it.
    """
    if len(mol.GetAtoms()) == 0:
        return None

    data = mol_to_graph_data(mol)
    if data is None:
        # e.g. the molecule contains an atom with atomic number 0.
        return None

    data['atom_pos'] = np.array(atom_poses, 'float32')
    data['bond_length'] = Compound3DKit.get_bond_lengths(data['edges'], data['atom_pos'])
    # dir_type was previously ignored (the default 'HT' was always used).
    BondAngleGraph_edges, bond_angles, bond_angle_dirs = \
        Compound3DKit.get_superedge_angles(data['edges'], data['atom_pos'], dir_type=dir_type)
    data['BondAngleGraph_edges'] = BondAngleGraph_edges
    data['bond_angle'] = np.array(bond_angles, 'float32')
    return data
|
|
|
|
|
def mol_to_geognn_graph_data_MMFF3d(mol):
    """
    Build GeoGNN graph data, choosing coordinates by molecule size.

    Molecules with more than 400 atoms use freshly computed 2D coordinates.
    Smaller ones use ``Compound3DKit.get_MMFF_atom_poses``, which reads the
    molecule's existing conformer.
    """
    if len(mol.GetAtoms()) > 400:
        atom_poses = Compound3DKit.get_2d_atom_poses(mol)
    else:
        mol, atom_poses = Compound3DKit.get_MMFF_atom_poses(mol, numConfs=10)
    return mol_to_geognn_graph_data(mol, atom_poses, dir_type='HT')
|
|
|
|
|
def mol_to_geognn_graph_data_raw3d(mol):
    """Build GeoGNN graph data from the coordinates of the molecule's existing default conformer."""
    conformer = mol.GetConformer()
    poses = Compound3DKit.get_atom_poses(mol, conformer)
    return mol_to_geognn_graph_data(mol, poses, dir_type='HT')
|
|
|
def obtain_3D_mol(smiles, name):
    """
    Embed 3D conformers for ``smiles``, MMFF-optimize them and save ``<name>.mol``.

    Hydrogens are added for embedding/optimization and removed again before
    writing and returning. ``Chem.MolToMolFile`` writes the default conformer.

    Args:
        smiles: SMILES string.
        name: output path prefix; the file written is ``name + '.mol'``.

    Returns:
        the hydrogen-stripped rdkit mol carrying the embedded conformers.

    Raises:
        ValueError: if ``smiles`` cannot be parsed or no conformer could be embedded.
    """
    mol = AllChem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError('Invalid SMILES: %r' % (smiles,))
    new_mol = Chem.AddHs(mol)
    conf_ids = AllChem.EmbedMultipleConfs(new_mol)
    if len(conf_ids) == 0:
        # Without a conformer the .mol file would carry no real coordinates and
        # later GetConformer() calls would fail.
        raise ValueError('Failed to embed a 3D conformer for SMILES: %r' % (smiles,))
    AllChem.MMFFOptimizeMoleculeConfs(new_mol)
    new_mol = Chem.RemoveHs(new_mol)
    Chem.MolToMolFile(new_mol, name + '.mol')
    return new_mol
|
|
|
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous, which is
# a debugging aid and slows GPU execution; consider dropping it outside debugging.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# Suppress all Python warnings process-wide.
warnings.filterwarnings('ignore')
|
|
|
|
|
# Run configuration. MODEL / test_mode / transfer_target are presumably consumed by
# code later in the file — TODO confirm.
MODEL = 'Test'
test_mode='fixed'
transfer_target='All_column'
# Selects whether 'bond_length' is included in bond_float_names and where the
# column_* descriptors are attached (see the Use_column_info block below).
Use_geometry_enhanced=True
# When True, 'coated'/'immobilized' bond ids, 'diameter' and column_* float names are appended below.
Use_column_info=True
|
|
|
# Categorical atom features fed to AtomEncoder (same list as the local one in mol_to_graph_data).
atom_id_names = [
    "atomic_num", "chiral_tag", "degree", "explicit_valence",
    "formal_charge", "hybridization", "implicit_valence",
    "is_aromatic", "total_numHs",
]
# Categorical bond features; extended in place with column flags when Use_column_info is set.
bond_id_names = [
    "bond_dir", "bond_type", "is_in_ring"]
|
|
|
# Float-valued edge features; geometry-enhanced mode adds the 3D bond length.
# A single if/else guarantees bond_float_names is always defined.
if Use_geometry_enhanced:
    bond_float_names = ["bond_length", 'prop']
else:
    bond_float_names = ['prop']

# Float-valued features for the bond-angle graph ('bond_angle' comes from
# Compound3DKit.get_superedge_angles); may be extended with column_* names below.
bond_angle_float_names = ['bond_angle', 'TPSA', 'RASA', 'RPSA', 'MDEC', 'MATS']
|
|
|
# Per-column metadata keyed by column name. The last element equals the column's
# position in column_name. The first three look like [coated flag, diameter,
# immobilized flag], matching the 'coated' / 'diameter' / 'immobilized' features
# added below — TODO confirm.
column_specify={'ADH':[1,5,0,0],'ODH':[1,5,0,1],'IC':[0,5,1,2],'IA':[0,5,1,3],'OJH':[1,5,0,4],
                'ASH':[1,5,0,5],'IC3':[0,3,1,6],'IE':[0,5,1,7],'ID':[0,5,1,8],'OD3':[1,3,0,9],
                'IB':[0,5,1,10],'AD':[1,10,0,11],'AD3':[1,3,0,12],'IF':[0,5,1,13],'OD':[1,10,0,14],
                'AS':[1,10,0,15],'OJ3':[1,3,0,16],'IG':[0,5,1,17],'AZ':[1,10,0,18],'IAH':[0,5,1,19],
                'OJ':[1,10,0,20],'ICH':[0,5,1,21],'OZ3':[1,3,0,22],'IF3':[0,3,1,23],'IAU':[0,1.6,1,24]}
|
# One SMILES per column (25 entries, same count as column_name); presumably aligned
# by index with column_name / the index stored in column_specify — verify before reordering.
column_smile=['O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(C)=C4',
              'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@@H]1OC)NC4=CC(C)=CC(C)=C4',
              'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(Cl)=CC(Cl)=C2)=O)[C@@H](OC(NC3=CC(Cl)=CC(Cl)=C3)=O)[C@@H]1OC)NC4=CC(Cl)=CC(Cl)=C4',
              'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(C)=C4',
              'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(C2=CC=C(C)C=C2)=O)[C@@H](OC(C3=CC=C(C)C=C3)=O)[C@@H]1OC)C4=CC=C(C)C=C4',
              'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(N[C@@H](C)C2=CC=CC=C2)=O)[C@@H](OC(N[C@@H](C)C3=CC=CC=C3)=O)[C@H]1OC)N[C@@H](C)C4=CC=CC=C4',
              'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(Cl)=CC(Cl)=C2)=O)[C@@H](OC(NC3=CC(Cl)=CC(Cl)=C3)=O)[C@@H]1OC)NC4=CC(Cl)=CC(Cl)=C4',
              'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(Cl)=CC(Cl)=C2)=O)[C@@H](OC(NC3=CC(Cl)=CC(Cl)=C3)=O)[C@H]1OC)NC4=CC(Cl)=CC(Cl)=C4',
              'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC=CC(Cl)=C2)=O)[C@@H](OC(NC3=CC=CC(Cl)=C3)=O)[C@H]1OC)NC4=CC=CC(Cl)=C4',
              'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@@H]1OC)NC4=CC(C)=CC(C)=C4',
              'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@@H]1OC)NC4=CC(C)=CC(C)=C4',
              'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(C)=C4',
              'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(C)=C4',
              'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC=C(C)C(Cl)=C2)=O)[C@@H](OC(NC3=CC=C(C)C(Cl)=C3)=O)[C@H]1OC)NC4=CC=C(C)C(Cl)=C4',
              'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@@H]1OC)NC4=CC(C)=CC(C)=C4',
              'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(N[C@@H](C)C2=CC=CC=C2)=O)[C@@H](OC(N[C@@H](C)C3=CC=CC=C3)=O)[C@H]1OC)N[C@@H](C)C4=CC=CC=C4',
              'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(C2=CC=C(C)C=C2)=O)[C@@H](OC(C3=CC=C(C)C=C3)=O)[C@@H]1OC)C4=CC=C(C)C=C4',
              'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(Cl)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(Cl)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(Cl)=C4',
              'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC=C(C)C(Cl)=C2)=O)[C@@H](OC(NC3=CC=C(C)C(Cl)=C3)=O)[C@H]1OC)NC4=CC=C(C)C(Cl)=C4',
              'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(C)=C4',
              'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(C2=CC=C(C)C=C2)=O)[C@@H](OC(C3=CC=C(C)C=C3)=O)[C@@H]1OC)C4=CC=C(C)C=C4',
              'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(Cl)=CC(Cl)=C2)=O)[C@@H](OC(NC3=CC(Cl)=CC(Cl)=C3)=O)[C@@H]1OC)NC4=CC(Cl)=CC(Cl)=C4',
              'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC=C(C)C(Cl)=C2)=O)[C@@H](OC(NC3=CC=C(C)C(Cl)=C3)=O)[C@@H]1OC)NC4=CC=C(C)C(Cl)=C4',
              'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC=C(C)C(Cl)=C2)=O)[C@@H](OC(NC3=CC=C(C)C(Cl)=C3)=O)[C@H]1OC)NC4=CC=C(C)C(Cl)=C4',
              'O=C(OC[C@@H](O1)[C@@H](OC)[C@H](OC(NC2=CC(C)=CC(C)=C2)=O)[C@@H](OC(NC3=CC(C)=CC(C)=C3)=O)[C@H]1OC)NC4=CC(C)=CC(C)=C4']
|
# Column names in index order; position i matches the last element of column_specify[name].
column_name=['ADH','ODH','IC','IA','OJH','ASH','IC3','IE','ID','OD3', 'IB','AD','AD3',
             'IF','OD','AS','OJ3','IG','AZ','IAH','OJ','ICH','OZ3','IF3','IAU']
|
# Vocabulary sizes used to size the AtomEncoder / BondEncoder embedding tables.
# Computed before bond_id_names is extended with column flags; those two flags get
# their sizes appended to full_bond_feature_dims separately below.
full_atom_feature_dims = get_atom_feature_dims(atom_id_names)
full_bond_feature_dims = get_bond_feature_dims(bond_id_names)
|
|
|
|
|
# Optionally append column-information features to the edge / bond-angle feature lists.
if Use_column_info==True:
    # Two extra categorical edge features and one float edge feature.
    bond_id_names.extend(['coated', 'immobilized'])
    bond_float_names.extend(['diameter'])
    # NOTE(review): 'column_TPSA' appears three times; by analogy with
    # bond_angle_float_names ('TPSA', 'RASA', 'RPSA', 'MDEC', 'MATS') the 2nd/3rd entries
    # may have been intended as 'column_RASA' / 'column_RPSA'. Verify against the code that
    # reads these names before changing anything.
    if Use_geometry_enhanced==True:
        bond_angle_float_names.extend(['column_TPSA', 'column_TPSA', 'column_TPSA', 'column_MDEC', 'column_MATS'])
    else:
        bond_float_names.extend(['column_TPSA', 'column_TPSA', 'column_TPSA', 'column_MDEC', 'column_MATS'])
    # Vocabulary sizes for the two categorical flags ('coated', 'immobilized') added above.
    full_bond_feature_dims.extend([2,2])
|
|
|
# Mordred descriptor calculator built from the mordred `descriptors` module, with 3D descriptors enabled.
calc = Calculator(descriptors, ignore_3D=False)
|
|
|
|
|
class AtomEncoder(torch.nn.Module):
    """Embed integer-coded atom features by summing one embedding per column.

    One ``torch.nn.Embedding`` is created for every entry of the module-level
    ``full_atom_feature_dims``; each table has ``dim + 5`` rows and is
    initialised with Xavier-uniform weights.

    Args:
        emb_dim: size of each embedding vector.
    """

    def __init__(self, emb_dim):
        super().__init__()
        self.atom_embedding_list = torch.nn.ModuleList()
        for vocab_size in full_atom_feature_dims:
            # NOTE(review): the extra 5 rows look like headroom for ids beyond
            # the nominal vocabulary — confirm against the featurizer.
            table = torch.nn.Embedding(vocab_size + 5, emb_dim)
            torch.nn.init.xavier_uniform_(table.weight.data)
            self.atom_embedding_list.append(table)

    def forward(self, x):
        """Sum ``atom_embedding_list[c](x[:, c])`` over every column ``c`` of ``x``.

        Args:
            x: integer tensor indexed as ``x[:, c]``; column ``c`` is looked up
                in the ``c``-th embedding table.

        Returns:
            The summed embeddings (the integer 0 when ``x`` has no columns).
        """
        n_columns = x.shape[1]
        return sum(self.atom_embedding_list[col](x[:, col]) for col in range(n_columns))
|
|
|
class BondEncoder(torch.nn.Module):
    """Embed integer-coded bond features by summing one embedding per column.

    One ``torch.nn.Embedding`` is created for every entry of the module-level
    ``full_bond_feature_dims``; each table has ``dim + 5`` rows and is
    initialised with Xavier-uniform weights.

    Args:
        emb_dim: size of each embedding vector.
    """

    def __init__(self, emb_dim):
        super().__init__()
        self.bond_embedding_list = torch.nn.ModuleList()
        for vocab_size in full_bond_feature_dims:
            # NOTE(review): the extra 5 rows look like headroom for ids beyond
            # the nominal vocabulary — confirm against the featurizer.
            table = torch.nn.Embedding(vocab_size + 5, emb_dim)
            torch.nn.init.xavier_uniform_(table.weight.data)
            self.bond_embedding_list.append(table)

    def forward(self, edge_attr):
        """Sum ``bond_embedding_list[c](edge_attr[:, c])`` over every column ``c``.

        Args:
            edge_attr: integer tensor indexed as ``edge_attr[:, c]``; column ``c``
                is looked up in the ``c``-th embedding table.

        Returns:
            The summed embeddings (the integer 0 when ``edge_attr`` has no columns).
        """
        n_columns = edge_attr.shape[1]
        return sum(self.bond_embedding_list[col](edge_attr[:, col]) for col in range(n_columns))
|
|
|
class RBF(torch.nn.Module):
    """Gaussian radial-basis expansion of a scalar feature.

    Every scalar ``x`` is mapped to ``exp(-gamma * (x - c_k) ** 2)`` for each
    center ``c_k``.

    Args:
        centers: tensor of RBF centers; flattened to shape ``(1, n_centers)``.
        gamma: width coefficient. It is assigned as a module attribute, so an
            ``nn.Parameter`` passed here is registered as a parameter.
        dtype: accepted for API compatibility; not used.
    """

    def __init__(self, centers, gamma, dtype='float32'):
        super().__init__()
        self.centers = centers.reshape([1, -1])
        self.gamma = gamma

    def forward(self, x):
        """Expand ``x`` into RBF activations.

        Args:
            x: tensor of scalars; flattened to shape ``(-1, 1)``.

        Returns:
            Tensor of shape ``(-1, n_centers)``.
        """
        column = x.reshape([-1, 1])
        squared_distance = (column - self.centers).pow(2)
        return torch.exp(-self.gamma * squared_distance)
|
|
|
class BondFloatRBF(torch.nn.Module):
    """Encode float bond features with one RBF expansion + linear layer per feature.

    Args:
        bond_float_names: ordered feature names; column ``i`` of the input to
            :meth:`forward` is encoded with the RBF parameters stored under
            ``bond_float_names[i]``.
        embed_dim: output embedding size.
        rbf_params: optional mapping ``name -> (centers, gamma)``. When omitted,
            a built-in table is used for bond_length, prop, diameter and the
            column_* descriptors.
    """

    def __init__(self, bond_float_names, embed_dim, rbf_params=None):
        super().__init__()
        self.bond_float_names = bond_float_names

        if rbf_params is not None:
            self.rbf_params = rbf_params
        else:
            def rbf_entry(start, stop, step, gamma):
                # (centers, gamma) pair, both wrapped as trainable parameters.
                return (nn.Parameter(torch.arange(start, stop, step)),
                        nn.Parameter(torch.Tensor([gamma])))

            self.rbf_params = {
                'bond_length': rbf_entry(0, 2, 0.1, 10.0),
                'prop': rbf_entry(0, 1, 0.05, 1.0),
                'diameter': rbf_entry(3, 12, 0.3, 1.0),
                'column_TPSA': (nn.Parameter(torch.arange(0, 1, 0.05).to(torch.float32)),
                                nn.Parameter(torch.Tensor([1.0]))),
                'column_RASA': rbf_entry(0, 1, 0.05, 1.0),
                'column_RPSA': rbf_entry(0, 1, 0.05, 1.0),
                'column_MDEC': rbf_entry(0, 10, 0.5, 2.0),
                'column_MATS': rbf_entry(0, 1, 0.05, 1.0),
            }

        self.linear_list = torch.nn.ModuleList()
        self.rbf_list = torch.nn.ModuleList()
        for feature_name in self.bond_float_names:
            centers, gamma = self.rbf_params[feature_name]
            self.rbf_list.append(RBF(centers.to(device), gamma.to(device)))
            self.linear_list.append(torch.nn.Linear(len(centers), embed_dim).to(device))

    def forward(self, bond_float_features):
        """Sum the per-feature embeddings.

        Args:
            bond_float_features: 2-D tensor; column ``i`` holds the feature
                named ``bond_float_names[i]``.

        Returns:
            ``sum_i linear_i(rbf_i(bond_float_features[:, i]))``.
        """
        return sum(
            self.linear_list[col](self.rbf_list[col](bond_float_features[:, col].reshape(-1, 1)))
            for col in range(len(self.bond_float_names))
        )
|
|
|
class BondAngleFloatRBF(torch.nn.Module):
    """Encode bond-angle features: RBF-expanded angle plus one linear over the rest.

    The 'bond_angle' entries get an RBF expansion followed by a linear layer.
    The first name that is not 'bond_angle' stops the scan. At that point a single
    linear layer covering ``len(bond_angle_float_names) - 1`` inputs is added.

    Note: the non-angle branch reads columns ``1:`` of the input regardless of
    where it occurs. The layout therefore assumes 'bond_angle' is the first name.

    Args:
        bond_angle_float_names: ordered feature names.
        embed_dim: output embedding size.
        rbf_params: optional mapping ``name -> (centers, gamma)``; defaults to
            an entry for 'bond_angle' only.
    """

    def __init__(self, bond_angle_float_names, embed_dim, rbf_params=None):
        super().__init__()
        self.bond_angle_float_names = bond_angle_float_names

        if rbf_params is not None:
            self.rbf_params = rbf_params
        else:
            angle_centers = nn.Parameter(torch.arange(0, torch.pi, 0.1))
            angle_gamma = nn.Parameter(torch.Tensor([10.0]))
            self.rbf_params = {'bond_angle': (angle_centers, angle_gamma)}

        self.linear_list = torch.nn.ModuleList()
        self.rbf_list = torch.nn.ModuleList()
        for feature_name in self.bond_angle_float_names:
            if feature_name != 'bond_angle':
                # One shared linear over every remaining (non-angle) column.
                self.linear_list.append(nn.Linear(len(self.bond_angle_float_names) - 1, embed_dim))
                break
            centers, gamma = self.rbf_params[feature_name]
            self.rbf_list.append(RBF(centers.to(device), gamma.to(device)))
            self.linear_list.append(nn.Linear(len(centers), embed_dim))

    def forward(self, bond_angle_float_features):
        """Embed the bond-angle feature matrix.

        Args:
            bond_angle_float_features: 2-D tensor; column 0 is expected to be the
                bond angle and columns ``1:`` the remaining float features.

        Returns:
            Sum of the RBF-angle embedding and the linear embedding of the rest.
        """
        out_embed = 0
        for idx, feature_name in enumerate(self.bond_angle_float_names):
            if feature_name != 'bond_angle':
                out_embed += self.linear_list[idx](bond_angle_float_features[:, 1:])
                break
            angle = bond_angle_float_features[:, idx].reshape(-1, 1)
            out_embed += self.linear_list[idx](self.rbf_list[idx](angle))
        return out_embed
|
|
|
class GINConv(MessagePassing):
    """Graph Isomorphism Network convolution with additive edge features.

    Each message is ``relu(x_j + edge_attr)``. Messages are summed per target
    node, and the node update is ``mlp((1 + eps) * x + aggregated)`` with a
    learnable ``eps``.

    Args:
        emb_dim (int): node (and edge) embedding dimensionality.
    """

    def __init__(self, emb_dim):
        super().__init__(aggr="add")
        hidden_layers = [
            nn.Linear(emb_dim, emb_dim),
            nn.BatchNorm1d(emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim),
        ]
        self.mlp = nn.Sequential(*hidden_layers)
        self.eps = nn.Parameter(torch.Tensor([0]))

    def forward(self, x, edge_index, edge_attr):
        aggregated = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        return self.mlp((1 + self.eps) * x + aggregated)

    def message(self, x_j, edge_attr):
        # Parameter names must stay x_j / edge_attr: PyG maps them by name.
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out
|
|
|
|
|
class GINNodeEmbedding(torch.nn.Module):
    """Stack of GINConv layers that embeds the nodes of a molecular graph.

    Output:
        Node representations. When the module-level flag
        ``Use_geometry_enhanced`` is True, a second GIN stack also runs over the
        bond-angle graph, and ``forward`` returns
        ``(node_representation, edge_representation)``.
    """

    def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last", residual=False):
        """GIN node embedding module: embeds graph nodes with stacked GINConv layers.

        Args:
            num_layers (int): number of message-passing layers; must be >= 2.
            emb_dim (int): embedding dimensionality.
            drop_ratio (float): dropout probability applied after every layer.
            JK (str): "last" returns the final layer's output. "sum" sums the
                outputs of all layers, including the input embedding.
            residual (bool): add each layer's input to its output.

        Raises:
            ValueError: if ``num_layers < 2`` or ``JK`` is not "last"/"sum".
        """
        super(GINNodeEmbedding, self).__init__()
        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.residual = residual

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        # forward() can only combine layers for these two modes. Any other value
        # used to surface later as an UnboundLocalError inside forward().
        if self.JK not in ("last", "sum"):
            raise ValueError(f"JK must be 'last' or 'sum', got {self.JK!r}.")

        self.atom_encoder = AtomEncoder(emb_dim)
        self.bond_encoder=BondEncoder(emb_dim)
        self.bond_float_encoder=BondFloatRBF(bond_float_names,emb_dim)
        # Not referenced in forward(); removing it would change the state_dict keys.
        self.bond_angle_encoder=BondAngleFloatRBF(bond_angle_float_names,emb_dim)

        self.convs = torch.nn.ModuleList()
        self.convs_bond_angle=torch.nn.ModuleList()
        self.convs_bond_float=torch.nn.ModuleList()
        self.convs_bond_embeding=torch.nn.ModuleList()
        self.convs_angle_float=torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        # Not referenced in forward(); kept for state_dict compatibility.
        self.batch_norms_ba = torch.nn.ModuleList()
        for layer in range(num_layers):
            self.convs.append(GINConv(emb_dim))
            self.convs_bond_angle.append(GINConv(emb_dim))
            self.convs_bond_embeding.append(BondEncoder(emb_dim))
            self.convs_bond_float.append(BondFloatRBF(bond_float_names,emb_dim))
            self.convs_angle_float.append(BondAngleFloatRBF(bond_angle_float_names,emb_dim))
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
            self.batch_norms_ba.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, batched_atom_bond,batched_bond_angle):
        """Embed the nodes (and, in geometry-enhanced mode, the bonds) of a batch.

        Args:
            batched_atom_bond: atom-bond graph providing ``x``, ``edge_index`` and
                ``edge_attr``. The first ``len(bond_id_names)`` columns of
                ``edge_attr`` are integer bond ids; the remaining columns are float
                bond features.
            batched_bond_angle: bond-angle graph whose nodes are the bonds. Only
                used when ``Use_geometry_enhanced`` is True.

        Returns:
            ``(node_representation, edge_representation)`` when
            ``Use_geometry_enhanced`` is True, ``node_representation`` when it is
            False (and None if the flag is neither).
        """
        x, edge_index, edge_attr = batched_atom_bond.x, batched_atom_bond.edge_index, batched_atom_bond.edge_attr
        edge_index_ba,edge_attr_ba= batched_bond_angle.edge_index, batched_bond_angle.edge_attr

        h_list = [self.atom_encoder(x)]

        if Use_geometry_enhanced==True:
            # Initial bond states: RBF-encoded float columns + embedded id columns.
            # (The slice end shape[1]+1 simply means "through the last column".)
            h_list_ba=[self.bond_float_encoder(edge_attr[:,len(bond_id_names):edge_attr.shape[1]+1].to(torch.float32))+self.bond_encoder(edge_attr[:,0:len(bond_id_names)].to(torch.int64))]
            for layer in range(self.num_layers):
                # Atom-bond graph: the previous layer's bond states act as edge features.
                h = self.convs[layer](h_list[layer], edge_index, h_list_ba[layer])
                # Bond-angle graph: bonds are nodes (re-encoded from the raw
                # edge_attr at every layer) and encoded angle features are edges.
                cur_h_ba=self.convs_bond_embeding[layer](edge_attr[:,0:len(bond_id_names)].to(torch.int64))+self.convs_bond_float[layer](edge_attr[:,len(bond_id_names):edge_attr.shape[1]+1].to(torch.float32))
                cur_angle_hidden=self.convs_angle_float[layer](edge_attr_ba)
                h_ba=self.convs_bond_angle[layer](cur_h_ba, edge_index_ba, cur_angle_hidden)

                # No batch norm in this branch; ReLU on every layer except the last.
                if layer == self.num_layers - 1:
                    h = F.dropout(h, self.drop_ratio, training=self.training)
                    h_ba = F.dropout(h_ba, self.drop_ratio, training=self.training)
                else:
                    h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
                    h_ba = F.dropout(F.relu(h_ba), self.drop_ratio, training=self.training)
                if self.residual:
                    h += h_list[layer]
                    h_ba+=h_list_ba[layer]
                h_list.append(h)
                h_list_ba.append(h_ba)

            if self.JK == "last":
                node_representation = h_list[-1]
                edge_representation = h_list_ba[-1]
            elif self.JK == "sum":
                node_representation = 0
                edge_representation = 0
                for layer in range(self.num_layers + 1):
                    node_representation += h_list[layer]
                    edge_representation += h_list_ba[layer]

            return node_representation,edge_representation

        if Use_geometry_enhanced==False:
            for layer in range(self.num_layers):
                # Bond embeddings are rebuilt from the raw edge_attr at every layer.
                h = self.convs[layer](h_list[layer], edge_index,
                                      self.convs_bond_embeding[layer](edge_attr[:, 0:len(bond_id_names)].to(torch.int64)) +
                                      self.convs_bond_float[layer](
                                          edge_attr[:, len(bond_id_names):edge_attr.shape[1] + 1].to(torch.float32)))
                h = self.batch_norms[layer](h)
                if layer == self.num_layers - 1:
                    h = F.dropout(h, self.drop_ratio, training=self.training)
                else:
                    h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

                if self.residual:
                    h += h_list[layer]

                h_list.append(h)

            if self.JK == "last":
                node_representation = h_list[-1]
            elif self.JK == "sum":
                node_representation = 0
                for layer in range(self.num_layers + 1):
                    node_representation += h_list[layer]

            return node_representation
|
|
|
class GINGraphPooling(nn.Module):
    """Graph-level GIN predictor: node embedding, graph pooling, linear head."""

    def __init__(self, num_tasks=1, num_layers=5, emb_dim=300, residual=False, drop_ratio=0, JK="last", graph_pooling="attention",
                 descriptor_dim=1781):
        """GIN graph pooling module.

        Embeds every node with ``GINNodeEmbedding`` and pools the node embeddings
        into one graph embedding. A single linear layer then maps that to the
        final graph representation.

        Args:
            num_tasks (int, optional): number of labels to predict, i.e. the width
                of the graph representation. Defaults to 1.
            num_layers (int, optional): number of GINConv layers. Defaults to 5.
            emb_dim (int, optional): node embedding dimension. Defaults to 300.
            residual (bool, optional): add residual connections. Defaults to False.
            drop_ratio (float, optional): dropout rate. Defaults to 0.
            JK (str, optional): "last" uses only the final layer's node
                embeddings; "sum" sums the embeddings of all layers.
                Defaults to "last".
            graph_pooling (str, optional): pooling of node embeddings, one of
                "sum", "mean", "max", "attention" or "set2set".
                Defaults to "attention".
            descriptor_dim (int, optional): input width of ``NN_descriptor``.
                Defaults to 1781.

        Raises:
            ValueError: if ``num_layers < 2`` or ``graph_pooling`` is unknown.
        """
        super(GINGraphPooling, self).__init__()

        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.descriptor_dim=descriptor_dim
        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.gnn_node = GINNodeEmbedding(num_layers, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual)

        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        elif graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn=nn.Sequential(
                nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, 1)))
        elif graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps=2)
        else:
            raise ValueError("Invalid graph pooling type.")

        if graph_pooling == "set2set":
            # Set2Set returns a 2 * emb_dim vector per graph, so the head must
            # accept twice the width. Both branches previously used emb_dim,
            # which made set2set fail with a shape mismatch.
            self.graph_pred_linear = nn.Linear(2 * self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)

        # NN_descriptor and sigmoid are not used in forward(); they remain so the
        # module's state_dict keys are unchanged.
        self.NN_descriptor = nn.Sequential(nn.Linear(self.descriptor_dim, self.emb_dim),
                                           nn.Sigmoid(),
                                           nn.Linear(self.emb_dim, self.emb_dim))

        self.sigmoid = nn.Sigmoid()

    def forward(self, batched_atom_bond,batched_bond_angle):
        """Predict graph-level outputs.

        Returns:
            ``(output, h_graph)``, where ``h_graph`` is the pooled graph embedding.
            In eval mode ``output`` is clamped to ``[0, 1e8]``.
        """
        if Use_geometry_enhanced==True:
            h_node,h_node_ba= self.gnn_node(batched_atom_bond,batched_bond_angle)
        else:
            h_node= self.gnn_node(batched_atom_bond, batched_bond_angle)
        h_graph = self.pool(h_node, batched_atom_bond.batch)
        output = self.graph_pred_linear(h_graph)
        if self.training:
            return output,h_graph
        else:
            return torch.clamp(output, min=0, max=1e8),h_graph
|
|
|
def mord(mol, nBits=1826, errors_as_zeros=True):
    """Compute the mordred descriptor vector of ``mol`` with the module-level ``calc``.

    Descriptor values flagged by ``is_missing`` are replaced with 0.

    Args:
        mol: molecule passed to ``calc``.
        nBits: length of the all-zero fallback vector.
        errors_as_zeros: on failure, return ``np.zeros(nBits, float32)`` if True,
            otherwise ``np.nan``.

    Returns:
        numpy array of descriptor values, or the fallback described above.
    """
    try:
        result = calc(mol)
        desc_list = [r if not is_missing(r) else 0 for r in result]
        return np.array(desc_list)
    except Exception:
        # Deliberate best-effort featurization: any descriptor failure yields a
        # fallback instead of aborting. ``except Exception`` no longer swallows
        # KeyboardInterrupt/SystemExit, and ``np.nan`` replaces the ``np.NaN``
        # alias removed in NumPy 2.0.
        return np.zeros((nBits,), dtype=np.float32) if errors_as_zeros else np.nan
|
|
|
def load_3D_mol():
    """Load every file directly inside ``mol_save/`` with ``Chem.MolFromMolFile``.

    Files are sorted by the integer in the name that starts five characters
    after the first ``_`` and ends at the first ``.``. For instance,
    ``3D_mol_12.mol`` sorts by 12.

    Returns:
        list of the loaded molecules, in sorted file order. The list is empty when
        the directory is missing or holds no files.
    """
    mol_dir = 'mol_save/'
    # Only the top level of mol_dir is used. The old loop over os.walk kept the
    # file list of whichever (sub-)directory was walked last, yet still joined
    # the names onto mol_dir. A missing directory also left the list unbound.
    _, _, file_names = next(os.walk(mol_dir), (mol_dir, [], []))
    file_names.sort(key=lambda x: int(x[x.find('_') + 5:x.find(".")]))
    return [Chem.MolFromMolFile(mol_dir + file_name) for file_name in file_names]
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the app.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as the zero-argument call always did.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='Graph data mining with GNN')
    # ``%(default)s`` keeps each help text in sync with its real default.
    # Several hand-written "(default: ...)" values had drifted; for example,
    # emb_dim advertised 256 while defaulting to 128.
    parser.add_argument('--task_name', type=str, default='GINGraphPooling',
                        help='task name')
    parser.add_argument('--device', type=int, default=0,
                        help='which gpu to use if any (default: %(default)s)')
    parser.add_argument('--num_layers', type=int, default=5,
                        help='number of GNN message passing layers (default: %(default)s)')
    parser.add_argument('--graph_pooling', type=str, default='sum',
                        help='graph pooling strategy: sum, mean, max, attention or set2set '
                             '(default: %(default)s)')
    parser.add_argument('--emb_dim', type=int, default=128,
                        help='dimensionality of hidden units in GNNs (default: %(default)s)')
    parser.add_argument('--drop_ratio', type=float, default=0.,
                        help='dropout ratio (default: %(default)s)')
    parser.add_argument('--save_test', action='store_true')
    parser.add_argument('--batch_size', type=int, default=2048,
                        help='input batch size for training (default: %(default)s)')
    parser.add_argument('--epochs', type=int, default=1000,
                        help='number of epochs to train (default: %(default)s)')
    parser.add_argument('--weight_decay', type=float, default=0.00001,
                        help='weight decay (default: %(default)s)')
    parser.add_argument('--early_stop', type=int, default=10,
                        help='early stop (default: %(default)s)')
    parser.add_argument('--num_workers', type=int, default=0,
                        help='number of workers (default: %(default)s)')
    parser.add_argument('--dataset_root', type=str, default="dataset",
                        help='dataset root (default: %(default)s)')
    return parser.parse_args(argv)
|
|
|
def calc_dragon_type_desc(mol):
    """Return a flat list of RDKit 3D descriptors followed by six scalar descriptors.

    The list is the concatenation of AUTOCORR3D, MORSE, RDF and WHIM
    descriptors, then [exact mol. weight, TPSA, rotatable bonds, H-bond donors,
    H-bond acceptors, logP].
    """
    # Scalars are evaluated first, matching the original evaluation order.
    scalar_functions = (
        Descriptors.ExactMolWt,
        Chem.rdMolDescriptors.CalcTPSA,
        Descriptors.NumRotatableBonds,
        Descriptors.NumHDonors,
        Descriptors.NumHAcceptors,
        Descriptors.MolLogP,
    )
    scalars = [fn(mol) for fn in scalar_functions]
    descriptors_3d = (rdMolDescriptors.CalcAUTOCORR3D(mol)
                      + rdMolDescriptors.CalcMORSE(mol)
                      + rdMolDescriptors.CalcRDF(mol)
                      + rdMolDescriptors.CalcWHIM(mol))
    return descriptors_3d + scalars
|
|
|
|
|
def eval(model, device, loader_atom_bond, loader_bond_angle):
    """Return the mean squared error of the model's column-1 prediction.

    Args:
        model: network called as ``model(atom_bond_batch, bond_angle_batch)``
            whose first return value is indexed as ``pred[:, 1]``. Presumably
            column 1 is the central prediction (columns 0/2 act as interval
            bounds in cal_prob); TODO confirm.
        device: device the batches are moved to.
        loader_atom_bond: iterable of atom-bond batches carrying targets in ``.y``.
        loader_bond_angle: iterable of bond-angle batches, consumed in lockstep.

    Returns:
        0-d numpy array with the MSE between the flattened ``y`` targets and
        ``pred[:, 1]``.
    """
    model.eval()
    y_true = []
    y_pred = []
    with torch.no_grad():
        for batch_atom_bond, batch_bond_angle in zip(loader_atom_bond, loader_bond_angle):
            batch_atom_bond = batch_atom_bond.to(device)
            batch_bond_angle = batch_bond_angle.to(device)
            pred = model(batch_atom_bond, batch_bond_angle)[0]
            y_true.append(batch_atom_bond.y.detach().cpu().reshape(-1))
            # Only column 1 enters the metric. Columns 0 and 2 used to be
            # collected and concatenated too, but were never used.
            y_pred.append(pred[:, 1].detach().cpu())
    y_true = torch.cat(y_true, dim=0)
    y_pred = torch.cat(y_pred, dim=0)
    return torch.mean((y_true - y_pred) ** 2).data.numpy()
|
|
|
|
|
def cal_prob(prediction):
    """Calculate the separation probability Sp of two compounds.

    ``prediction[k][0]`` is read as ``[low, mid, high]`` for compound ``k``;
    only ``low`` (index 0) and ``high`` (index 2) are used.

    Returns:
        1 if the two ``[low, high]`` intervals do not overlap. Otherwise
        ``1 - overlap / span``, where ``span`` is the length of the smallest
        interval covering both. Returns 0 when that span is zero, i.e. both
        intervals are the same single point.
    """
    a = prediction[0][0]
    b = prediction[1][0]
    if a[2] < b[0] or a[0] > b[2]:
        return 1
    overlap = min(a[2], b[2]) - max(a[0], b[0])
    span = max(a[2], b[2]) - min(a[0], b[0])
    if span == 0:
        # Identical point predictions are fully overlapping, so there is no
        # separation. This was previously a division by zero (ZeroDivisionError,
        # or NaN for NumPy scalars).
        return 0
    return 1 - overlap / span
|
|
|
|
|
|
|
# Build the inference model once at import time; predict_separate and
# column_recommendation load the trained weights into it on every call.
args = parse_args()
nn_params = dict(
    num_tasks=3,
    num_layers=args.num_layers,
    emb_dim=args.emb_dim,
    drop_ratio=args.drop_ratio,
    graph_pooling=args.graph_pooling,
    descriptor_dim=1827,
)
device = 'cpu'
model = GINGraphPooling(**nn_params).to(device)
|
|
|
|
|
''' |
|
Given two compounds, predict their retention times (RT) under the specified chromatographic conditions.
|
''' |
|
|
|
|
|
def predict_separate(smile_1, smile_2, input_eluent, input_speed, input_column):
    """Predict two compounds on one HPLC column and their separation probability.

    Args:
        smile_1: SMILES of the first compound.
        smile_2: SMILES of the second compound.
        input_eluent: eluent ratio, broadcast onto every bond as the ``prop`` feature.
        input_speed: flow rate; the raw model outputs are divided by it.
        input_column: column type, used as a key into the module-level ``column_specify``.

    Returns:
        str: the predicted value (output index 1) for each compound and the
        separation probability from ``cal_prob``. Returns a validation message
        instead when an input is missing or invalid.
    """
    if input_speed is None:
        return 'Please input Speed!'
    if input_speed == 0:
        return 'Speed cannot be 0!'
    if input_eluent is None:
        return 'Please input eluent!'
    if input_column not in column_specify:
        # An empty dropdown selection previously crashed with a KeyError below.
        return 'Please select a column type!'

    smiles = [smile_1, smile_2]
    for label, smile in zip(('smile_1', 'smile_2'), smiles):
        # The parsed molecules used to be built and then discarded. Use the parse
        # to reject bad input here rather than failing deep in the 3D pipeline.
        if not smile or Chem.MolFromSmiles(smile) is None:
            return f'Invalid SMILES for {label}!'

    speed = [input_speed] * len(smiles)
    eluent = [input_eluent] * len(smiles)

    def filled(length, value):
        """Return a 1-D tensor of ``length`` copies of ``value`` (``ones * value``)."""
        return torch.ones([length]) * value

    model.load_state_dict(
        torch.load('GeoGNN_model.pth', map_location=torch.device('cpu')), strict=False)
    model.eval()
    column_descriptor = np.load('column_descriptor.npy', allow_pickle=True)
    col_specify = column_specify[input_column]
    col_des = np.array(column_descriptor[col_specify[3]])

    all_descriptor = []
    dataset = []
    for smile in smiles:
        # NOTE(review): obtain_3D_mol(smile, 'conform') is expected to write
        # conform.mol, which is read straight back. A fixed, shared file name is
        # not safe when several requests run concurrently.
        obtain_3D_mol(smile, 'conform')
        mol = Chem.MolFromMolFile("conform.mol")
        all_descriptor.append(mord(mol))
        dataset.append(mol_to_geognn_graph_data_MMFF3d(mol))

    y_pred = []
    for i, data in enumerate(dataset):
        atom_feature = torch.from_numpy(
            np.array([data[name] for name in atom_id_names]).T).to(torch.int64)
        bond_feature = torch.from_numpy(
            np.array([data[name] for name in bond_id_names[0:3]]).T).to(torch.int64)
        bond_float_feature = torch.from_numpy(data['bond_length'].astype(np.float32))
        bond_angle_feature = torch.from_numpy(data['bond_angle'].astype(np.float32))
        y = torch.Tensor([float(speed[i])])
        edge_index = torch.from_numpy(data['edges'].T).to(torch.int64)
        bond_index = torch.from_numpy(data['BondAngleGraph_edges'].T).to(torch.int64)

        n_bonds = bond_feature.shape[0]
        n_angles = bond_angle_feature.shape[0]
        descriptor = all_descriptor[i]

        # Per-bond columns: 3 graph ids, coated, immobilized, bond length,
        # eluent (prop), diameter.
        bond_feature = torch.cat([
            bond_feature,
            filled(n_bonds, col_specify[0]).reshape(-1, 1),   # coated
            filled(n_bonds, col_specify[2]).reshape(-1, 1),   # immobilized
            bond_float_feature.reshape(-1, 1),
            filled(n_bonds, eluent[i]).reshape(-1, 1),        # prop
            filled(n_bonds, col_specify[1]).reshape(-1, 1),   # diameter
        ], dim=1)

        # Per-angle columns: the angle, five compound descriptors, then the
        # same five descriptors of the column.
        angle_columns = [
            bond_angle_feature,
            filled(n_angles, descriptor[820]) / 100,   # TPSA
            filled(n_angles, descriptor[821]),         # RASA
            filled(n_angles, descriptor[822]),         # RPSA
            filled(n_angles, descriptor[1568]),        # MDEC
            filled(n_angles, descriptor[457]),         # MATS
            filled(n_angles, col_des[820]) / 100,      # column TPSA
            filled(n_angles, col_des[821]),            # column RASA
            filled(n_angles, col_des[822]),            # column RPSA
            filled(n_angles, col_des[1568]),           # column MDEC
            filled(n_angles, col_des[457]),            # column MATS
        ]
        bond_angle_feature = torch.cat([c.reshape(-1, 1) for c in angle_columns], dim=1)

        data_atom_bond = Data(atom_feature, edge_index, bond_feature, y)
        data_bond_angle = Data(edge_index=bond_index, edge_attr=bond_angle_feature)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            pred = model(data_atom_bond.to(device), data_bond_angle.to(device))[0]
        y_pred.append(pred.detach().cpu().data.numpy() / speed[i])

    Sp = cal_prob(y_pred)
    output_1 = f'For smile_1,\n the predicted value is: {str(np.round(y_pred[0][0][1],3))}\n'
    output_2 = f'For smile_2,\n the predicted value is: {str(np.round(y_pred[1][0][1],3))}\n'
    output_3 = f'The separation probability is: {str(np.round(Sp,3))}'
    return output_1 + output_2 + output_3
|
|
|
|
|
def column_recommendation(smile_1, smile_2, input_eluent, input_speed):
    """Score every column type in ``column_specify`` for separating two compounds.

    Args:
        smile_1: SMILES of the first compound.
        smile_2: SMILES of the second compound.
        input_eluent: eluent ratio, broadcast onto every bond as the ``prop`` feature.
        input_speed: flow rate; the raw model outputs are divided by it.

    Returns:
        pandas.DataFrame with columns Column_name, RT_1, RT_2 and
        Separation_probability. RT values are formatted with 2 decimals and Sp as
        a percentage. Rows are sorted by separation probability (descending), and
        rows containing a 0 or NaN value are dropped. Returns a validation message
        string instead when an input is missing or invalid.
    """
    if input_speed is None:
        return 'Please input Speed!'
    if input_speed == 0:
        return 'Speed cannot be 0!'
    if input_eluent is None:
        return 'Please input eluent!'

    smiles = [smile_1, smile_2]
    for label, smile in zip(('smile_1', 'smile_2'), smiles):
        # Reject unparsable input here rather than failing deep in the 3D pipeline.
        if not smile or Chem.MolFromSmiles(smile) is None:
            return f'Invalid SMILES for {label}!'

    speed = [input_speed] * len(smiles)
    eluent = [input_eluent] * len(smiles)

    def filled(length, value):
        """Return a 1-D tensor of ``length`` copies of ``value`` (``ones * value``)."""
        return torch.ones([length]) * value

    model.load_state_dict(
        torch.load('GeoGNN_model.pth', map_location=torch.device('cpu')), strict=False)
    model.eval()

    # Loop-invariant work is done once instead of once per column: loading the
    # descriptor table, and generating each compound's conformer, mordred
    # descriptors and graph. NOTE(review): if obtain_3D_mol is stochastic,
    # all columns are now scored on the same conformer. obtain_3D_mol is
    # expected to write conform.mol, which is read straight back.
    column_descriptor = np.load('column_descriptor.npy', allow_pickle=True)
    all_descriptor = []
    dataset = []
    for smile in smiles:
        obtain_3D_mol(smile, 'conform')
        mol = Chem.MolFromMolFile("conform.mol")
        all_descriptor.append(mord(mol))
        dataset.append(mol_to_geognn_graph_data_MMFF3d(mol))

    Prediction = []
    Sp = []
    for predict_column in column_specify.keys():
        col_specify = column_specify[predict_column]
        col_des = np.array(column_descriptor[col_specify[3]])
        y_pred = []
        for i, data in enumerate(dataset):
            atom_feature = torch.from_numpy(
                np.array([data[name] for name in atom_id_names]).T).to(torch.int64)
            bond_feature = torch.from_numpy(
                np.array([data[name] for name in bond_id_names[0:3]]).T).to(torch.int64)
            bond_float_feature = torch.from_numpy(data['bond_length'].astype(np.float32))
            bond_angle_feature = torch.from_numpy(data['bond_angle'].astype(np.float32))
            y = torch.Tensor([float(speed[i])])
            edge_index = torch.from_numpy(data['edges'].T).to(torch.int64)
            bond_index = torch.from_numpy(data['BondAngleGraph_edges'].T).to(torch.int64)

            n_bonds = bond_feature.shape[0]
            n_angles = bond_angle_feature.shape[0]
            descriptor = all_descriptor[i]

            # Per-bond columns: 3 graph ids, coated, immobilized, bond length,
            # eluent (prop), diameter.
            bond_feature = torch.cat([
                bond_feature,
                filled(n_bonds, col_specify[0]).reshape(-1, 1),   # coated
                filled(n_bonds, col_specify[2]).reshape(-1, 1),   # immobilized
                bond_float_feature.reshape(-1, 1),
                filled(n_bonds, eluent[i]).reshape(-1, 1),        # prop
                filled(n_bonds, col_specify[1]).reshape(-1, 1),   # diameter
            ], dim=1)

            # Per-angle columns: the angle, five compound descriptors, then the
            # same five descriptors of the column.
            angle_columns = [
                bond_angle_feature,
                filled(n_angles, descriptor[820]) / 100,   # TPSA
                filled(n_angles, descriptor[821]),         # RASA
                filled(n_angles, descriptor[822]),         # RPSA
                filled(n_angles, descriptor[1568]),        # MDEC
                filled(n_angles, descriptor[457]),         # MATS
                filled(n_angles, col_des[820]) / 100,      # column TPSA
                filled(n_angles, col_des[821]),            # column RASA
                filled(n_angles, col_des[822]),            # column RPSA
                filled(n_angles, col_des[1568]),           # column MDEC
                filled(n_angles, col_des[457]),            # column MATS
            ]
            bond_angle_feature = torch.cat([c.reshape(-1, 1) for c in angle_columns], dim=1)

            data_atom_bond = Data(atom_feature, edge_index, bond_feature, y)
            data_bond_angle = Data(edge_index=bond_index, edge_attr=bond_angle_feature)

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                pred = model(data_atom_bond.to(device), data_bond_angle.to(device))[0]
            y_pred.append(pred.detach().cpu().data.numpy() / speed[i])
        Prediction.append(y_pred)
        Sp.append(cal_prob(y_pred))

    # (columns, compounds, outputs). reshape instead of squeeze, so a single
    # column type does not lose its axis.
    predictions = np.array(Prediction).reshape(len(Prediction), len(smiles), -1)
    Prediction_1 = predictions[:, 0, 1]
    Prediction_2 = predictions[:, 1, 1]
    Sp = np.array(Sp)
    result = pd.DataFrame({'Column_name': list(column_specify.keys()), 'RT_1': Prediction_1,
                           'RT_2': Prediction_2, 'Separation_probability': Sp})
    # Masking with ``!= 0`` turns zero entries into NaN, so dropna() removes
    # every row that contains a zero or NaN value.
    result = result[result.loc[:] != 0].dropna()
    result['RT_1'] = result['RT_1'].apply(lambda x: format(x, '.2f'))
    result['RT_2'] = result['RT_2'].apply(lambda x: format(x, '.2f'))
    result = result.sort_values(by="Separation_probability", ascending=False)
    result['Separation_probability'] = result['Separation_probability'].apply(lambda x: format(x, '.2%'))
    return result
|
|
|
|
|
|
|
if __name__=='__main__':
    # Markdown shown on the first tab of the app. Typos in the user-facing text
    # are fixed (e.g. "input_spped" -> "input_speed", the actual input name).
    model_card = f"""

## Description\n

It is an app for predicting retention times in HPLC and recommending the best HPLC column type for chromatographic enantioseparation.\n\n

Input:\n

·smile_1 and smile_2: SMILES of two molecules (especially enantiomers)\n

·input_eluent: the ratio of eluent (hexane/2-propanol). For example: input 0.02 for hexane/2-propanol=98/02\n

·input_speed: the flow rate of HPLC (mL/min)\n

·column_name: select a column type in the dropdown\n

Output:\n

·The predicted retention times of the two molecules\n

·The separation probability (Sp) of the two molecules; a higher Sp indicates that the molecules are easier to separate by HPLC under the given conditions (see Citation 1).\n

## Citation\n

If you use our software, we would appreciate it if you gave us credit in the acknowledgements section of your paper, for example:\n

We use RF prediction software in our synthesis work. [Citation 1, Citation 2]\n

Citation1: H. Xu, J. Lin, D. Zhang, F. Mo, Retention Time Prediction for Chromatographic Enantioseparation by Quantile Geometry-enhanced Graph Neural Network, arxiv:2211.03602\n

Citation2: https://huggingface.co/spaces/woshixuhao/Chromatographic_Enantioseparation \n

Business applications require authorization!\n

## Function\n

Single prediction: predict the retention times of two compounds and their separation probability under a given condition (eluent, flow rate and column type)\n

Column recommendation: give the separation probability of two molecules (especially enantiomers) under all column types\n

"""
    demo_mark = gr.Blocks()

    with demo_mark:
        gr.Markdown('''

<div>

<h1 style='text-align: center'>Chromatographic enantioseparation prediction</h1>

</div>

''')
        gr.Markdown(model_card)

    # The dropdown reuses the module-level column list, so the two cannot drift apart.
    demo_1 = gr.Interface(fn=predict_separate,
                          inputs=["text", "text", "number", "number",
                                  gr.Dropdown(list(column_name), label="Column type",
                                              info="Choose a HPLC column")],
                          outputs=['text'])
    demo_2 = gr.Interface(fn=column_recommendation, inputs=["text", "text", "number", "number"],
                          outputs=['dataframe'])
    demo = gr.TabbedInterface([demo_mark, demo_1, demo_2],
                              ["Markdown", "Single prediction", "Column recommendation"])
    demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|