|
from models.MoE_ECGFormer import MoE_ECGFormer |
|
from data.dataloader import ECGDataloader |
|
from configs.data_configs import get_dataset_class |
|
from configs.hparams import get_hparams_class |
|
from utils import AverageMeter, to_device, _save_metrics, copy_files |
|
from utils import fix_randomness, starting_logs, save_checkpoint, _calc_metrics |
|
import torch |
|
import torch.nn.functional as F |
|
import datetime |
|
import os |
|
import collections |
|
import numpy as np |
|
|
|
import warnings |
|
import sklearn.exceptions |
|
|
|
warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning) |
|
warnings.simplefilter(action='ignore', category=FutureWarning) |
|
|
|
|
|
class Trainer(object):
    """Supervised training/evaluation driver for MoE_ECGFormer on ECG data.

    Sets up the experiment log directory, loads the train/valid/test splits,
    runs the training loop with class-weighted cross-entropy, checkpoints the
    best (by validation F1) and last models, and reports test-set metrics.
    """

    def __init__(self, args):
        # Experiment identity and hardware.
        self.dataset = args.dataset
        self.seed_id = args.seed_id
        self.device = torch.device(args.device)

        # Timestamp the run description so repeated runs do not collide.
        self.run_description = f"{args.run_description}_{datetime.datetime.now().strftime('%H_%M')}"
        self.experiment_description = args.experiment_description

        # Paths: experiments_logs/<experiment>/<run> under the CWD.
        self.home_path = os.getcwd()
        self.save_dir = os.path.join(os.getcwd(), "experiments_logs")
        self.exp_log_dir = os.path.join(self.save_dir, self.experiment_description, self.run_description)
        os.makedirs(self.exp_log_dir, exist_ok=True)

        self.data_path = args.data_path
        self.num_runs = args.num_runs

        # Dataset-specific configuration and training hyperparameters.
        self.dataset_configs, self.hparams_class = self.get_configs()
        self.hparams = self.hparams_class.train_params

    def get_configs(self):
        """Return (dataset-config instance, supervised-hparams instance)."""
        dataset_class = get_dataset_class(self.dataset)
        hparams_class = get_hparams_class("Supervised")
        return dataset_class(), hparams_class()

    def load_data(self, data_type):
        """Build the train/test/valid dataloaders and cache the class-weight dict.

        The class weights returned alongside the train loader are used later
        to build the weighted training criterion.
        """
        # NOTE(review): three separate ECGDataloader instances are built here;
        # if the constructor loads the data, one shared instance would avoid
        # re-reading it — confirm against data/dataloader.py before changing.
        self.train_dl, self.cw_dict = ECGDataloader(self.data_path, data_type, self.hparams).train_dataloader()
        self.test_dl = ECGDataloader(self.data_path, data_type, self.hparams).test_dataloader()
        self.valid_dl = ECGDataloader(self.data_path, data_type, self.hparams).valid_dataloader()

    def calc_results_per_run(self):
        """Compute (accuracy, f1) from the labels cached by the last evaluate() call."""
        acc, f1 = _calc_metrics(self.pred_labels, self.true_labels, self.dataset_configs.class_names)
        return acc, f1

    def train(self):
        """Run the full training loop, checkpoint best/last models, report test metrics."""
        # Snapshot the source files used for this run into the log dir.
        copy_files(self.exp_log_dir)

        self.metrics = {'accuracy': [], 'f1_score': []}

        # Reproducibility.
        fix_randomness(int(self.seed_id))

        # Logging.
        self.logger, self.scenario_log_dir = starting_logs(self.dataset, self.exp_log_dir, self.seed_id)
        self.logger.debug(self.hparams)

        # Data.
        self.load_data(self.dataset)

        model = MoE_ECGFormer(configs=self.dataset_configs, hparams=self.hparams)
        model.to(self.device)

        self.optimizer = torch.optim.Adam(
            model.parameters(),
            lr=self.hparams["learning_rate"],
            weight_decay=self.hparams["weight_decay"],
            betas=(0.9, 0.99)
        )

        # Class-weighted cross-entropy to counter label imbalance; the weights
        # come from the training dataloader (self.cw_dict).
        weights = [float(value) for value in self.cw_dict.values()]
        weights_array = np.array(weights).astype(np.float32)
        weights_tensor = torch.tensor(weights_array).to(self.device)
        self.cross_entropy = torch.nn.CrossEntropyLoss(weight=weights_tensor)

        best_acc = 0
        best_f1 = 0

        # Last-epoch validation metrics (reported after the loop).
        ts_acc = 0
        ts_f1 = 0
        for epoch in range(1, self.hparams["num_epochs"] + 1):
            model.train()
            # Fresh meters each epoch so the logged loss is a per-epoch
            # average (previously a single meter accumulated across epochs,
            # making later epochs' logged loss a running average of all of them).
            loss_avg_meters = collections.defaultdict(lambda: AverageMeter())

            for step, batches in enumerate(self.train_dl):
                batches = to_device(batches, self.device)

                data = batches['samples'].float()
                labels = batches['labels'].long()

                self.optimizer.zero_grad()

                logits = model(data)

                x_ent_loss = self.cross_entropy(logits, labels)

                x_ent_loss.backward()
                self.optimizer.step()

                losses = {'Total_loss': x_ent_loss.item()}
                for key, val in losses.items():
                    loss_avg_meters[key].update(val, self.hparams["batch_size"])

            # Training-split metrics. The original evaluated the validation
            # set here while labelling the result "TRAIN" (and then evaluated
            # the same validation set again below); evaluate the training
            # split so the log line matches its label.
            self.evaluate(model, self.train_dl)
            tr_acc, tr_f1 = self.calc_results_per_run()

            self.logger.debug(f'[Epoch : {epoch}/{self.hparams["num_epochs"]}]')
            for key, val in loss_avg_meters.items():
                self.logger.debug(f'{key}\t: {val.avg:2.4f}')
            self.logger.debug(f'TRAIN: Acc:{tr_acc:2.4f} \t F1:{tr_f1:2.4f}')

            # Validation metrics; keep the checkpoint with the best F1.
            self.evaluate(model, self.valid_dl)
            ts_acc, ts_f1 = self.calc_results_per_run()
            if ts_f1 > best_f1:
                best_f1 = ts_f1
                best_acc = ts_acc
                save_checkpoint(self.exp_log_dir, model, self.dataset, self.dataset_configs, self.hparams, "best")
                _save_metrics(self.pred_labels, self.true_labels, self.exp_log_dir, "validation_best")

            self.logger.debug(f'VAL : Acc:{ts_acc:2.4f} \t F1:{ts_f1:2.4f} (best: {best_f1:2.4f})')
            self.logger.debug('-------------------------------------')

        # Last-epoch validation performance (pred/true labels still hold the
        # final validation evaluation from the loop above).
        _save_metrics(self.pred_labels, self.true_labels, self.exp_log_dir, "validation_last")
        self.logger.debug("LAST EPOCH PERFORMANCE on validation set...")
        self.logger.debug(f'Acc:{ts_acc:2.4f} \t F1:{ts_f1:2.4f}')

        self.logger.debug(":::::::::::::")

        self.logger.debug("BEST EPOCH PERFORMANCE on validation set ...")
        self.logger.debug(f'Acc:{best_acc:2.4f} \t F1:{best_f1:2.4f}')
        save_checkpoint(self.exp_log_dir, model, self.dataset, self.dataset_configs, self.hparams, "last")

        # Final evaluation on the held-out test set with the last-epoch model.
        print(" === Evaluating on TEST set ===")
        self.evaluate(model, self.test_dl)
        test_acc, test_f1 = self.calc_results_per_run()
        _save_metrics(self.pred_labels, self.true_labels, self.exp_log_dir, "test_last")
        self.logger.debug(f'Acc:{test_acc:2.4f} \t F1:{test_f1:2.4f}')

    def evaluate(self, model, dataset):
        """Evaluate `model` over `dataset`, caching results on self.

        Side effects: sets self.pred_labels / self.true_labels (1-D numpy
        arrays of argmax predictions and ground truth) and self.trg_loss
        (mean unweighted cross-entropy over the batches).
        """
        model.to(self.device).eval()

        total_loss_ = []

        self.pred_labels = np.array([])
        self.true_labels = np.array([])

        with torch.no_grad():
            for batches in dataset:
                batches = to_device(batches, self.device)
                data = batches['samples'].float()
                labels = batches['labels'].long()

                predictions = model(data)

                # Unweighted loss here: the class-weighted criterion is used
                # only for training.
                loss = F.cross_entropy(predictions, labels)
                total_loss_.append(loss.item())
                pred = predictions.detach().argmax(dim=1)

                self.pred_labels = np.append(self.pred_labels, pred.cpu().numpy())
                self.true_labels = np.append(self.true_labels, labels.data.cpu().numpy())

        self.trg_loss = torch.tensor(total_loss_).mean()
|
|