# eduardosoares99's picture
# Update finetune code (#8)
# 30e2e3e verified
# Deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from utils import CustomDataset, CustomDatasetMultitask, RMSELoss, normalize_smiles
# Data
import pandas as pd
import numpy as np
# Standard library
import random
import args
import os
import shutil
from tqdm import tqdm
# Machine Learning
from sklearn.metrics import mean_absolute_error, r2_score, accuracy_score, roc_auc_score, roc_curve, auc, precision_recall_curve
from scipy import stats
from utils import RMSE, sensitivity, specificity
class Trainer:
    """Base fine-tuning driver.

    Builds the data loaders, runs the epoch loop, writes checkpoints and
    evaluates a saved checkpoint on the test split. Subclasses must implement
    ``_train_one_epoch`` and ``_validate_one_epoch``.
    """

    def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
                 target_metric='rmse', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints',
                 restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
        """Store the configuration, build the data loaders and seed the RNGs.

        Args:
            raw_data: Indexable holding the train, validation and test
                DataFrames at positions 0, 1 and 2; each needs a ``smiles``
                column.
            dataset_name: Label used in log output and output file names.
            target: Target column name, or a sequence of names.
            batch_size: Batch size shared by the three loaders.
            hparams: Object whose ``vars()`` is stored in every checkpoint.
            target_metric: Stored on the instance; not read by this class.
            seed: Seed applied to ``random``, NumPy and PyTorch.
            smi_ted_version: ``'v1'`` or ``'v2'``; selects the loader module
                imported by ``evaluate``.
            checkpoints_folder: Directory for checkpoints and predictions.
            restart_filename: Checkpoint inside ``checkpoints_folder`` that
                ``compile`` restores when given.
            save_every_epoch: Keep a checkpoint for every epoch instead of
                only the best one.
            save_ckpt: When ``False`` no checkpoint is written.
            device: Device the model and batches are moved to.
        """
        # data
        self.df_train = raw_data[0]
        self.df_valid = raw_data[1]
        self.df_test = raw_data[2]
        self.dataset_name = dataset_name
        self.target = target
        self.batch_size = batch_size
        self.hparams = hparams
        self._prepare_data()

        # config
        self.target_metric = target_metric
        self.seed = seed
        self.smi_ted_version = smi_ted_version
        self.checkpoints_folder = checkpoints_folder
        self.restart_filename = restart_filename
        self.start_epoch = 1
        self.save_every_epoch = save_every_epoch
        self.save_ckpt = save_ckpt
        self.device = device
        self.best_vloss = float('inf')
        self.last_filename = None
        self._set_seed(seed)

    def _prepare_data(self):
        """Canonicalise SMILES, drop unusable rows and build the loaders.

        The ``canon_smiles`` column is written onto the frames that were
        passed in; rows whose canonical form is missing are then dropped.
        """
        # normalize dataset
        self.df_train['canon_smiles'] = self.df_train['smiles'].apply(normalize_smiles)
        self.df_valid['canon_smiles'] = self.df_valid['smiles'].apply(normalize_smiles)
        self.df_test['canon_smiles'] = self.df_test['smiles'].apply(normalize_smiles)

        self.df_train = self.df_train.dropna(subset=['canon_smiles'])
        self.df_valid = self.df_valid.dropna(subset=['canon_smiles'])
        self.df_test = self.df_test.dropna(subset=['canon_smiles'])

        # create dataloader (only the training split is shuffled)
        self.train_loader = DataLoader(
            CustomDataset(self.df_train, self.target),
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=True
        )
        self.valid_loader = DataLoader(
            CustomDataset(self.df_valid, self.target),
            batch_size=self.batch_size,
            shuffle=False,
            pin_memory=True
        )
        self.test_loader = DataLoader(
            CustomDataset(self.df_test, self.target),
            batch_size=self.batch_size,
            shuffle=False,
            pin_memory=True
        )

    def compile(self, model, optimizer, loss_fn):
        """Attach model, optimizer and loss; restore a checkpoint if requested."""
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self._print_configuration()
        if self.restart_filename:
            self._load_checkpoint(self.restart_filename)
            print('Checkpoint restored!')

    def fit(self, max_epochs=500):
        """Train and validate from ``start_epoch`` up to ``max_epochs`` inclusive.

        A checkpoint is written whenever the validation loss improves, or on
        every epoch when ``save_every_epoch`` is set (both only if
        ``save_ckpt`` is true).
        """
        # Moving the model is idempotent, so once before the loop is enough.
        self.model.to(self.device)

        for epoch in range(self.start_epoch, max_epochs+1):
            print(f'\n=====Epoch [{epoch}/{max_epochs}]=====')

            # training
            self.model.train()
            self._train_one_epoch()

            # validation
            self.model.eval()
            val_preds, val_loss, val_metrics = self._validate_one_epoch(self.valid_loader)
            for name, value in val_metrics.items():
                print(f"[VALID] Evaluation {name.upper()}: {round(value, 4)}")

            # Only a real improvement may move the best loss; otherwise
            # `save_every_epoch` would overwrite it with a worse value.
            improved = val_loss < self.best_vloss
            if improved:
                self.best_vloss = val_loss

            if self.save_ckpt and (improved or self.save_every_epoch):
                self._write_epoch_checkpoint(epoch, val_loss)

    def _write_epoch_checkpoint(self, epoch, val_loss):
        """Save this epoch's checkpoint, then drop the one it supersedes."""
        model_name = f'{str(self.model)}-Finetune'
        new_filename = f"{model_name}_seed{self.seed}_{self.dataset_name}_epoch={epoch}_valloss={round(val_loss, 4)}.pt"

        print('Saving checkpoint...')
        self._save_checkpoint(epoch, new_filename)

        # Swap only after the save succeeded so `last_filename` always names
        # a file that exists, and the old best survives a failed write.
        stale_filename, self.last_filename = self.last_filename, new_filename
        if (
            stale_filename is not None
            and stale_filename != new_filename
            and not self.save_every_epoch
        ):
            os.remove(os.path.join(self.checkpoints_folder, stale_filename))

    def evaluate(self, verbose=True):
        """Reload a saved checkpoint and score it on the test split.

        Predictions are written as CSV into ``checkpoints_folder``.

        Raises:
            ValueError: ``smi_ted_version`` is neither ``'v1'`` nor ``'v2'``.
            RuntimeError: No checkpoint is available to evaluate.
        """
        if verbose:
            print("\n=====Test Evaluation=====")

        if self.smi_ted_version == 'v1':
            import smi_ted_light.load as load
        elif self.smi_ted_version == 'v2':
            import smi_ted_large.load as load
        else:
            raise ValueError('Please, specify the SMI-TED version: `v1` or `v2`.')

        ckpt_filename = self.last_filename
        if ckpt_filename is None and self.save_ckpt:
            # Resumed run in which no epoch was saved: the checkpoint we
            # restarted from is still the best one on disk.
            ckpt_filename = self.restart_filename
        if ckpt_filename is None:
            raise RuntimeError(
                'No checkpoint available for evaluation: nothing was saved '
                'during `fit` and no restart checkpoint was given.'
            )

        # copy vocabulary to checkpoint folder
        if not os.path.exists(os.path.join(self.checkpoints_folder, 'bert_vocab_curated.txt')):
            smi_ted_path = os.path.dirname(load.__file__)
            shutil.copy(os.path.join(smi_ted_path, 'bert_vocab_curated.txt'), self.checkpoints_folder)

        # load model for inference
        model_inf = load.load_smi_ted(
            folder=self.checkpoints_folder,
            ckpt_filename=ckpt_filename,
            eval=True,
        ).to(self.device)

        # set model evaluation mode
        model_inf.eval()

        # evaluate on test set
        tst_preds, tst_loss, tst_metrics = self._validate_one_epoch(self.test_loader, model_inf)

        if verbose:
            # show metrics
            for name, value in tst_metrics.items():
                print(f"[TEST] Evaluation {name.upper()}: {round(value, 4)}")

        # save predictions
        pd.DataFrame(tst_preds).to_csv(
            os.path.join(
                self.checkpoints_folder,
                f'{self.dataset_name}_{self.target if isinstance(self.target, str) else self.target[0]}_predict_test_seed{self.seed}.csv'),
            index=False
        )

    def _train_one_epoch(self):
        """Run one training pass; subclasses return the mean batch loss."""
        raise NotImplementedError

    def _validate_one_epoch(self, data_loader, model=None):
        """Score ``data_loader``; subclasses return ``(preds, loss, metrics)``."""
        raise NotImplementedError

    def _print_configuration(self):
        """Print a summary of the fine-tuning setup."""
        print('----Finetune information----')
        print('Dataset:\t', self.dataset_name)
        print('Target:\t\t', self.target)
        print('Batch size:\t', self.batch_size)
        print('LR:\t\t', self._get_lr())
        print('Device:\t\t', self.device)
        print('Optimizer:\t', self.optimizer.__class__.__name__)
        print('Loss function:\t', self.loss_fn.__class__.__name__)
        print('Seed:\t\t', self.seed)
        print('Train size:\t', self.df_train.shape[0])
        print('Valid size:\t', self.df_valid.shape[0])
        print('Test size:\t', self.df_test.shape[0])

    def _load_checkpoint(self, filename):
        """Restore model weights, the next epoch and the best validation loss."""
        ckpt_path = os.path.join(self.checkpoints_folder, filename)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        ckpt_dict = torch.load(ckpt_path, map_location='cpu')
        self.model.load_state_dict(ckpt_dict['MODEL_STATE'])
        self.start_epoch = ckpt_dict['EPOCHS_RUN'] + 1
        self.best_vloss = ckpt_dict['finetune_info']['best_vloss']

    def _save_checkpoint(self, current_epoch, filename):
        """Write model weights plus run metadata to ``checkpoints_folder``."""
        os.makedirs(self.checkpoints_folder, exist_ok=True)

        ckpt_dict = {
            'MODEL_STATE': self.model.state_dict(),
            'EPOCHS_RUN': current_epoch,
            'hparams': vars(self.hparams),
            'finetune_info': {
                'dataset': self.dataset_name,
                # key previously carried a stray backtick ('target`')
                'target': self.target,
                'batch_size': self.batch_size,
                'lr': self._get_lr(),
                'device': self.device,
                'optim': self.optimizer.__class__.__name__,
                'loss_fn': self.loss_fn.__class__.__name__,
                'train_size': self.df_train.shape[0],
                'valid_size': self.df_valid.shape[0],
                'test_size': self.df_test.shape[0],
                'best_vloss': self.best_vloss,
            },
            'seed': self.seed,
        }

        # Sanity check on the top-level checkpoint layout.
        assert list(ckpt_dict.keys()) == ['MODEL_STATE', 'EPOCHS_RUN', 'hparams', 'finetune_info', 'seed']

        torch.save(ckpt_dict, os.path.join(self.checkpoints_folder, filename))

    def _set_seed(self, value):
        """Seed Python, NumPy and PyTorch; make cuDNN deterministic on CUDA."""
        random.seed(value)
        torch.manual_seed(value)
        np.random.seed(value)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(value)
            torch.cuda.manual_seed_all(value)
            cudnn.deterministic = True
            cudnn.benchmark = False

    def _get_lr(self):
        """Return the learning rate of the first parameter group, if any."""
        groups = self.optimizer.param_groups
        return groups[0]['lr'] if groups else None
class TrainerRegressor(Trainer):
    """Trainer for single-target regression.

    The model must expose ``extract_embeddings(smiles)`` and ``net(embeddings)``.
    """

    def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
                 target_metric='rmse', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints',
                 restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
        """Forward every argument unchanged to ``Trainer.__init__``."""
        super().__init__(raw_data, dataset_name, target, batch_size, hparams,
                         target_metric, seed, smi_ted_version, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device)

    def _predict(self, model, smiles):
        """Return the model outputs for one batch with the batch axis intact."""
        embeddings = model.extract_embeddings(smiles).to(self.device)
        outputs = model.net(embeddings).squeeze()
        # `squeeze()` collapses a one-sample batch to a 0-dim tensor, which
        # no longer matches the targets; put the batch axis back.
        if outputs.dim() == 0:
            outputs = outputs.unsqueeze(0)
        return outputs

    def _train_one_epoch(self):
        """Run one optimisation pass over the training loader.

        Returns:
            Mean of the per-batch losses.
        """
        running_loss = 0.0

        for idx, (smiles, targets) in enumerate(pbar := tqdm(self.train_loader, desc='[TRAINING]')):
            targets = targets.to(self.device)

            # zero the parameter gradients (otherwise they are accumulated)
            self.optimizer.zero_grad()

            # forward, loss, backward, step
            outputs = self._predict(self.model, smiles)
            loss = self.loss_fn(outputs, targets)
            loss.backward()
            self.optimizer.step()

            # running mean of the loss shown on the progress bar
            running_loss += loss.item()
            pbar.set_postfix(loss=running_loss/(idx+1))

        return running_loss / len(self.train_loader)

    def _validate_one_epoch(self, data_loader, model=None):
        """Score ``data_loader`` without gradient tracking.

        Args:
            data_loader: Loader yielding ``(smiles, targets)`` batches.
            model: Model to evaluate; defaults to ``self.model``.

        Returns:
            Tuple ``(preds, mean_loss, metrics)`` where ``preds`` is a flat
            NumPy array and ``metrics`` has the keys ``mae``, ``r2``,
            ``rmse`` and ``spearman``.
        """
        batch_targets = []
        batch_preds = []
        running_loss = 0.0

        model = self.model if model is None else model

        with torch.no_grad():
            for idx, (smiles, targets) in enumerate(pbar := tqdm(data_loader, desc='[EVALUATION]')):
                targets = targets.to(self.device)

                predictions = self._predict(model, smiles)
                loss = self.loss_fn(predictions, targets)

                batch_targets.append(targets.view(-1))
                batch_preds.append(predictions.view(-1))

                running_loss += loss.item()
                pbar.set_postfix(loss=running_loss/(idx+1))

        # Put together predictions and labels from batches
        preds = torch.cat(batch_preds, dim=0).cpu().numpy()
        tgts = torch.cat(batch_targets, dim=0).cpu().numpy()

        # Tuple unpacking works on both older and newer SciPy result objects,
        # unlike the `.statistic` attribute.
        spearman, _ = stats.spearmanr(tgts, preds)

        metrics = {
            'mae': mean_absolute_error(tgts, preds),
            'r2': r2_score(tgts, preds),
            'rmse': RMSE(preds, tgts),
            'spearman': spearman,
        }

        return preds, running_loss / len(data_loader), metrics
class TrainerClassifier(Trainer):
    """Trainer for single-task classification.

    The loss receives the network outputs and integer class labels; metrics
    use column 1 of the softmax as the positive-class score.
    """

    def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
                 target_metric='roc-auc', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints',
                 restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
        """Forward every argument unchanged to ``Trainer.__init__``."""
        super().__init__(raw_data, dataset_name, target, batch_size, hparams,
                         target_metric, seed, smi_ted_version, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device)

    def _predict(self, model, smiles):
        """Return the network outputs for one batch as a 2-D tensor."""
        embeddings = model.extract_embeddings(smiles).to(self.device)
        logits = model.net(embeddings).squeeze()
        # `squeeze()` drops the batch axis of a one-sample batch; without it
        # the loss call, `torch.cat` and `softmax(dim=1)` below all break.
        if logits.dim() == 1:
            logits = logits.unsqueeze(0)
        return logits

    def _train_one_epoch(self):
        """Run one optimisation pass over the training loader.

        Returns:
            Mean of the per-batch losses.
        """
        running_loss = 0.0

        for idx, (smiles, targets) in enumerate(pbar := tqdm(self.train_loader, desc='[TRAINING]')):
            targets = targets.to(self.device)

            # zero the parameter gradients (otherwise they are accumulated)
            self.optimizer.zero_grad()

            # forward, loss, backward, step
            outputs = self._predict(self.model, smiles)
            loss = self.loss_fn(outputs, targets.long())
            loss.backward()
            self.optimizer.step()

            # running mean of the loss shown on the progress bar
            running_loss += loss.item()
            pbar.set_postfix(loss=running_loss/(idx+1))

        return running_loss / len(self.train_loader)

    def _validate_one_epoch(self, data_loader, model=None):
        """Score ``data_loader`` without gradient tracking.

        Args:
            data_loader: Loader yielding ``(smiles, targets)`` batches.
            model: Model to evaluate; defaults to ``self.model``.

        Returns:
            Tuple ``(preds, mean_loss, metrics)`` where ``preds`` holds the
            raw network outputs as a NumPy array and ``metrics`` has the keys
            ``acc``, ``roc-auc``, ``prc-auc``, ``sensitivity`` and
            ``specificity``.
        """
        batch_targets = []
        batch_preds = []
        running_loss = 0.0

        model = self.model if model is None else model

        with torch.no_grad():
            for idx, (smiles, targets) in enumerate(pbar := tqdm(data_loader, desc='[EVALUATION]')):
                targets = targets.to(self.device)

                predictions = self._predict(model, smiles)
                loss = self.loss_fn(predictions, targets.long())

                batch_targets.append(targets.view(-1))
                batch_preds.append(predictions)

                running_loss += loss.item()
                pbar.set_postfix(loss=running_loss/(idx+1))

        # Put together predictions and labels from batches
        logits = torch.cat(batch_preds, dim=0).cpu()
        preds = logits.numpy()
        tgts = torch.cat(batch_targets, dim=0).cpu().numpy()

        # positive-class score and hard labels at the 0.5 threshold
        positive_score = F.softmax(logits, dim=1).numpy()[:, 1]
        y_pred = np.where(positive_score >= 0.5, 1, 0)

        fpr, tpr, _ = roc_curve(tgts, positive_score)
        precision, recall, _ = precision_recall_curve(tgts, positive_score)

        metrics = {
            'acc': accuracy_score(tgts, y_pred),
            'roc-auc': auc(fpr, tpr),
            'prc-auc': auc(recall, precision),
            'sensitivity': sensitivity(tgts, y_pred),
            'specificity': specificity(tgts, y_pred),
        }

        return preds, running_loss / len(data_loader), metrics
class TrainerClassifierMultitask(Trainer):
    """Trainer for multi-task classification with per-sample label masks.

    Batches are ``(smiles, targets, target_masks)``; network outputs are
    multiplied by the mask before the loss, and metrics are computed per task
    on the entries whose mask is positive.
    """

    def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
                 target_metric='roc-auc', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints',
                 restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
        """Forward every argument unchanged to ``Trainer.__init__``."""
        super().__init__(raw_data, dataset_name, target, batch_size, hparams,
                         target_metric, seed, smi_ted_version, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device)

    def _prepare_data(self):
        """Canonicalise SMILES, drop unusable rows and build multitask loaders.

        The ``canon_smiles`` column is written onto the frames that were
        passed in; rows whose canonical form is missing are then dropped.
        """
        # normalize dataset
        self.df_train['canon_smiles'] = self.df_train['smiles'].apply(normalize_smiles)
        self.df_valid['canon_smiles'] = self.df_valid['smiles'].apply(normalize_smiles)
        self.df_test['canon_smiles'] = self.df_test['smiles'].apply(normalize_smiles)

        self.df_train = self.df_train.dropna(subset=['canon_smiles'])
        self.df_valid = self.df_valid.dropna(subset=['canon_smiles'])
        self.df_test = self.df_test.dropna(subset=['canon_smiles'])

        # create dataloader (only the training split is shuffled)
        self.train_loader = DataLoader(
            CustomDatasetMultitask(self.df_train, self.target),
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=True
        )
        self.valid_loader = DataLoader(
            CustomDatasetMultitask(self.df_valid, self.target),
            batch_size=self.batch_size,
            shuffle=False,
            pin_memory=True
        )
        self.test_loader = DataLoader(
            CustomDatasetMultitask(self.df_test, self.target),
            batch_size=self.batch_size,
            shuffle=False,
            pin_memory=True
        )

    def _predict(self, model, smiles, target_masks):
        """Return the masked multitask outputs for one batch."""
        embeddings = model.extract_embeddings(smiles).to(self.device)
        # NOTE(review): with exactly one task `squeeze()` would also drop the
        # task axis and the mask multiply would broadcast - confirm that case
        # never occurs.
        outputs = model.net(embeddings, multitask=True).squeeze()
        return outputs * target_masks.to(self.device)

    def _train_one_epoch(self):
        """Run one optimisation pass over the training loader.

        Returns:
            Mean of the per-batch losses.
        """
        running_loss = 0.0

        for idx, (smiles, targets, target_masks) in enumerate(pbar := tqdm(self.train_loader, desc='[TRAINING]')):
            targets = targets.to(self.device)

            # zero the parameter gradients (otherwise they are accumulated)
            self.optimizer.zero_grad()

            # forward, loss, backward, step
            outputs = self._predict(self.model, smiles, target_masks)
            loss = self.loss_fn(outputs, targets)
            loss.backward()
            self.optimizer.step()

            # running mean of the loss shown on the progress bar
            running_loss += loss.item()
            pbar.set_postfix(loss=running_loss/(idx+1))

        return running_loss / len(self.train_loader)

    def _validate_one_epoch(self, data_loader, model=None):
        """Score ``data_loader`` without gradient tracking.

        Args:
            data_loader: Loader yielding ``(smiles, targets, target_masks)``.
            model: Model to evaluate; defaults to ``self.model``.

        Returns:
            Tuple ``(preds, mean_loss, metrics)`` where ``preds`` holds the
            masked outputs as a NumPy array and every entry of ``metrics``
            (``acc``, ``roc-auc``, ``prc-auc``, ``sensitivity``,
            ``specificity``) is the mean over all tasks.
        """
        batch_targets = []
        batch_preds = []
        batch_masks = []
        running_loss = 0.0

        model = self.model if model is None else model

        with torch.no_grad():
            for idx, (smiles, targets, target_masks) in enumerate(pbar := tqdm(data_loader, desc='[EVALUATION]')):
                targets = targets.to(self.device)

                predictions = self._predict(model, smiles, target_masks)
                loss = self.loss_fn(predictions, targets)

                batch_targets.append(targets)
                batch_preds.append(predictions)
                batch_masks.append(target_masks)

                running_loss += loss.item()
                pbar.set_postfix(loss=running_loss/(idx+1))

        # Put together predictions, labels and masks; metrics run on the CPU,
        # so move everything there once instead of once per task.
        preds = torch.cat(batch_preds, dim=0).cpu()
        tgts = torch.cat(batch_targets, dim=0).cpu()
        mask = torch.cat(batch_masks, dim=0).cpu() > 0

        per_task = {'acc': [], 'roc-auc': [], 'prc-auc': [], 'sensitivity': [], 'specificity': []}

        for task_idx in range(len(self.target)):
            keep = mask[:, task_idx]
            actuals = torch.masked_select(tgts[:, task_idx], keep).numpy()
            scores = torch.masked_select(preds[:, task_idx], keep).numpy()

            # hard labels at the 0.5 threshold
            y_pred = np.where(scores >= 0.5, 1, 0)
            precision, recall, _ = precision_recall_curve(actuals, scores)

            per_task['acc'].append(accuracy_score(actuals, y_pred))
            per_task['sensitivity'].append(sensitivity(actuals, y_pred))
            per_task['specificity'].append(specificity(actuals, y_pred))
            per_task['roc-auc'].append(roc_auc_score(actuals, scores))
            per_task['prc-auc'].append(auc(recall, precision))

        # Accuracy is averaged like every other metric; previously only the
        # last task's accuracy was reported.
        metrics = {
            name: torch.mean(torch.tensor(values)).item()
            for name, values in per_task.items()
        }

        return preds.numpy(), running_loss / len(data_loader), metrics