|
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
from sklearn.metrics import confusion_matrix
|
|
|
|
|
|
import pandas as pd
|
|
import numpy as np
|
|
|
|
|
|
import os
|
|
|
|
|
|
from rdkit import Chem
|
|
from rdkit.Chem import PandasTools
|
|
from rdkit.Chem import Descriptors
|
|
# Module-level side effect at import time: render RDKit molecule objects as
# images whenever a DataFrame containing them is displayed (e.g. in notebooks).
PandasTools.RenderImagesInAllDataFrames(True)
|
|
|
|
|
|
def normalize_smiles(smi, canonical=True, isomeric=False):
    """Return the RDKit-normalized form of a SMILES string, or None.

    Args:
        smi: input SMILES string.
        canonical: passed to ``Chem.MolToSmiles`` (canonical output order).
        isomeric: include stereochemistry in the output if True.

    Returns:
        The normalized SMILES string, or None when ``smi`` cannot be parsed.
    """
    try:
        mol = Chem.MolFromSmiles(smi)
        # MolFromSmiles returns None (no exception) for unparseable SMILES;
        # handle it explicitly rather than letting MolToSmiles raise.
        if mol is None:
            return None
        return Chem.MolToSmiles(mol, canonical=canonical, isomericSmiles=isomeric)
    except Exception:
        # Original used a bare `except:` which also swallowed KeyboardInterrupt
        # and SystemExit; catch only genuine errors from RDKit.
        return None
|
|
|
|
|
|
class RMSELoss:
    """Callable computing root-mean-square error between two tensors."""

    def __init__(self):
        pass

    def __call__(self, yhat, y):
        # sqrt of the mean squared residual, kept as a torch op so it
        # stays differentiable for use as a training loss.
        squared_error = (yhat - y) ** 2
        return torch.sqrt(squared_error.mean())
|
|
|
|
|
|
def RMSE(predictions, targets):
    """Return the root-mean-square error between two numpy arrays."""
    squared_diff = np.square(predictions - targets)
    return np.sqrt(np.mean(squared_diff))
|
|
|
|
|
|
def sensitivity(y_true, y_pred):
    """Return sensitivity (recall / true-positive rate) for binary 0/1 labels.

    Args:
        y_true: ground-truth binary labels (0/1).
        y_pred: predicted binary labels (0/1).

    Returns:
        tp / (tp + fn). NOTE(review): division by zero when there are no
        positives in y_true — callers appear to assume both classes occur.
    """
    # labels=[0, 1] forces a 2x2 matrix even when the inputs contain only one
    # class; without it, ravel() yields a single value and the 4-way unpack
    # raises ValueError.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    return (tp/(tp+fn))
|
|
|
|
|
|
def specificity(y_true, y_pred):
    """Return specificity (true-negative rate) for binary 0/1 labels.

    Args:
        y_true: ground-truth binary labels (0/1).
        y_pred: predicted binary labels (0/1).

    Returns:
        tn / (tn + fp). NOTE(review): division by zero when there are no
        negatives in y_true — callers appear to assume both classes occur.
    """
    # labels=[0, 1] forces a 2x2 matrix even when the inputs contain only one
    # class; without it, ravel() yields a single value and the 4-way unpack
    # raises ValueError.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    return (tn/(tn+fp))
|
|
|
|
|
|
def get_optim_groups(module, keep_decoder=False, weight_decay=0.0):
    """Partition a module's parameters into weight-decay / no-decay groups.

    Linear layer weights are eligible for weight decay; all biases and the
    weights of LayerNorm/Embedding layers are exempt (decaying those is known
    to hurt training). Parameters whose module is in neither list (e.g. Conv
    weights) are silently dropped from both groups — preserved from the
    original behavior.

    Args:
        module: torch.nn.Module whose parameters are grouped.
        keep_decoder: if False, any parameter whose qualified name contains
            'decoder' is excluded entirely (e.g. to freeze/skip a decoder head).
        weight_decay: decay applied to the eligible group. Default 0.0
            preserves the original (hard-coded) behavior; pass a nonzero value
            to actually enable decay on the whitelist group.

    Returns:
        A two-element list of dicts usable as a torch optimizer's parameter
        groups: [decay group, no-decay group].
    """
    decay = set()
    no_decay = set()
    whitelist_weight_modules = (torch.nn.Linear,)
    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
    for mn, m in module.named_modules():
        for pn, p in m.named_parameters():
            # Fully qualified parameter name; named_parameters() recurses, so
            # the same parameter is visited more than once — the sets dedupe.
            fpn = '%s.%s' % (mn, pn) if mn else pn

            if not keep_decoder and 'decoder' in fpn:
                continue

            if pn.endswith('bias'):
                # All biases are exempt from weight decay.
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                no_decay.add(fpn)

    param_dict = {pn: p for pn, p in module.named_parameters()}

    optim_groups = [
        # Original hard-coded 0.0 here, making the split a no-op; the
        # weight_decay parameter (default 0.0) keeps that behavior while
        # letting callers enable decay.
        {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": weight_decay},
        {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
    ]

    return optim_groups
|
|
|
|
|
|
class CustomDataset(Dataset):
    """Torch Dataset yielding (SMILES string, label) pairs from a DataFrame.

    The DataFrame must contain a 'canon_smiles' column; ``target`` names the
    single label column to serve alongside each SMILES.
    """

    def __init__(self, dataset, target):
        self.dataset = dataset  # pandas DataFrame with a 'canon_smiles' column
        self.target = target    # name of the label column

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        row_smiles = self.dataset['canon_smiles'].iloc[idx]
        row_label = self.dataset[self.target].iloc[idx]
        return row_smiles, row_label
|
|
|
|
|
|
class CustomDatasetMultitask(Dataset):
    """Torch Dataset for multitask labels with missing values.

    Yields (smiles, labels, mask) where ``mask`` is 1.0 for targets that are
    present and 0.0 for NaN targets; missing labels are zero-filled so the
    tensors have a fixed shape and the mask can zero out their loss terms.
    """

    def __init__(self, dataset, targets):
        self.dataset = dataset  # pandas DataFrame with a 'canon_smiles' column
        self.targets = targets  # list of label column names

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        smiles = self.dataset['canon_smiles'].iloc[idx]
        labels = self.dataset[self.targets].iloc[idx].to_numpy(dtype=np.float64)
        # Vectorized NaN handling (original looped in Python per element):
        # compute the missing-value positions once, derive mask and fill.
        missing = np.isnan(labels)
        mask = (~missing).astype(np.float32)
        labels = np.where(missing, 0.0, labels)
        return smiles, torch.tensor(labels, dtype=torch.float32), torch.tensor(mask)